From 9a930add06d850b6a6b72d44fb6b02d827dc7b70 Mon Sep 17 00:00:00 2001 From: Michael Dwan Date: Wed, 29 Apr 2026 12:47:42 -0600 Subject: [PATCH 1/3] remove dead code: pkg/web, pkg/http, and unused env helpers These packages were confirmed unreachable by deadcode analysis. pkg/web (Replicate API client) and pkg/http (auth transport) lost their callers when the push flow moved to the custom registry pusher. Also removes the orphaned BuildInfo struct and Push function from pkg/docker/push.go, and the WebHost/APIHost env helpers that only those packages used. --- pkg/docker/push.go | 26 -- pkg/env/env.go | 18 -- pkg/env/env_test.go | 6 - pkg/http/client.go | 34 --- pkg/http/client_test.go | 30 -- pkg/http/transport.go | 40 --- pkg/http/transport_test.go | 75 ----- pkg/http/user_agent.go | 24 -- pkg/http/user_agent_test.go | 12 - pkg/web/client.go | 538 ------------------------------------ pkg/web/client_test.go | 204 -------------- 11 files changed, 1007 deletions(-) delete mode 100644 pkg/docker/push.go delete mode 100644 pkg/http/client.go delete mode 100644 pkg/http/client_test.go delete mode 100644 pkg/http/transport.go delete mode 100644 pkg/http/transport_test.go delete mode 100644 pkg/http/user_agent.go delete mode 100644 pkg/http/user_agent_test.go delete mode 100644 pkg/web/client.go delete mode 100644 pkg/web/client_test.go diff --git a/pkg/docker/push.go b/pkg/docker/push.go deleted file mode 100644 index 292e29f56f..0000000000 --- a/pkg/docker/push.go +++ /dev/null @@ -1,26 +0,0 @@ -package docker - -import ( - "context" - "net/http" - "time" - - "github.com/replicate/cog/pkg/docker/command" - "github.com/replicate/cog/pkg/util/console" - "github.com/replicate/cog/pkg/web" -) - -type BuildInfo struct { - BuildTime time.Duration - BuildID string -} - -func Push(ctx context.Context, image string, projectDir string, command command.Command, buildInfo BuildInfo, client *http.Client) error { - webClient := web.NewClient(command, client) - - if err := webClient.PostPushStart(ctx, buildInfo.BuildID, buildInfo.BuildTime); err != nil { - console.Warnf("Failed to send build timings to server: %v", err) - } - - return StandardPush(ctx, image, command) -} diff --git a/pkg/env/env.go b/pkg/env/env.go index f08f29a9d5..2f33761809 100644 --- a/pkg/env/env.go +++ b/pkg/env/env.go @@ -3,8 +3,6 @@ package env import "os" const SchemeEnvVarName = "R8_SCHEME" -const WebHostEnvVarName = "R8_WEB_HOST" -const APIHostEnvVarName = "R8_API_HOST" const PytorchHostEnvVarName = "R8_PYTORCH_HOST" func SchemeFromEnvironment() string { @@ -15,22 +13,6 @@ func SchemeFromEnvironment() string { return scheme } -func WebHostFromEnvironment() string { - host := os.Getenv(WebHostEnvVarName) - if host == "" { - host = "cog.replicate.com" - } - return host -} - -func APIHostFromEnvironment() string { - host := os.Getenv(APIHostEnvVarName) - if host == "" { - host = "api.replicate.com" - } - return host -} - func PytorchHostFromEnvironment() string { host := os.Getenv(PytorchHostEnvVarName) if host == "" { diff --git a/pkg/env/env_test.go b/pkg/env/env_test.go index ff42e3a069..14b3c0cb20 100644 --- a/pkg/env/env_test.go +++ b/pkg/env/env_test.go @@ -11,9 +11,3 @@ func TestSchemeFromEnvironment(t *testing.T) { t.Setenv(SchemeEnvVarName, "myscheme") require.Equal(t, SchemeFromEnvironment(), testScheme) } - -func TestWebHostFromEnvironment(t *testing.T) { - const testHost = "web" - t.Setenv(WebHostEnvVarName, testHost) - require.Equal(t, WebHostFromEnvironment(), testHost) -} diff --git a/pkg/http/client.go b/pkg/http/client.go deleted file mode 100644 index a1ee672d28..0000000000 --- a/pkg/http/client.go +++ /dev/null @@ -1,34 +0,0 @@ -package http - -import ( - "context" - "net/http" - - "github.com/replicate/cog/pkg/docker/command" - "github.com/replicate/cog/pkg/env" - "github.com/replicate/cog/pkg/global" -) - -const UserAgentHeader = "User-Agent" -const BearerHeaderPrefix = "Bearer " - -func ProvideHTTPClient(ctx context.Context, dockerCommand command.Command) (*http.Client, error) { - userInfo, err := dockerCommand.LoadUserInformation(ctx, global.ReplicateRegistryHost) - if err != nil { - return nil, err - } - - client := http.Client{ - Transport: &Transport{ - headers: map[string]string{ - UserAgentHeader: UserAgent(), - "Content-Type": "application/json", - }, - authentication: map[string]string{ - env.WebHostFromEnvironment(): BearerHeaderPrefix + userInfo.Token, - }, - }, - } - - return &client, nil -} diff --git a/pkg/http/client_test.go b/pkg/http/client_test.go deleted file mode 100644 index c68b60a821..0000000000 --- a/pkg/http/client_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package http - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/replicate/cog/pkg/docker/dockertest" -) - -func TestClientDecoratesUserAgent(t *testing.T) { - // Setup mock http server - seenUserAgent := false - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, r.Header.Get(UserAgentHeader), UserAgent()) - seenUserAgent = true - })) - defer server.Close() - - command := dockertest.NewMockCommand() - client, err := ProvideHTTPClient(t.Context(), command) - require.NoError(t, err) - - _, err = client.Get(server.URL) - require.NoError(t, err) - - require.True(t, seenUserAgent) -} diff --git a/pkg/http/transport.go b/pkg/http/transport.go deleted file mode 100644 index 502345b785..0000000000 --- a/pkg/http/transport.go +++ /dev/null @@ -1,40 +0,0 @@ -package http - -import ( - "errors" - "net/http" -) - -const AuthorizationHeader = "Authorization" - -type Transport struct { - headers map[string]string - authentication map[string]string - base http.RoundTripper -} - -func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { - // Write standard headers - for k, v := range t.headers { - if req.Header.Get(k) == "" { - req.Header.Set(k, v) - } - } - - // Write authentication - if req.Header.Get(AuthorizationHeader) == "" { - authorisation, ok := t.authentication[req.URL.Host] - if ok { - if authorisation == BearerHeaderPrefix { - return nil, errors.New("No token supplied for HTTP authorization. Have you run 'cog login'?") - } - req.Header.Set(AuthorizationHeader, authorisation) - } - } - - base := t.base - if base == nil { - base = http.DefaultTransport - } - return base.RoundTrip(req) -} diff --git a/pkg/http/transport_test.go b/pkg/http/transport_test.go deleted file mode 100644 index 49697e9d9c..0000000000 --- a/pkg/http/transport_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package http - -import ( - "net/http" - "net/http/httptest" - "net/url" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestTransportAddsHeaders(t *testing.T) { - // Setup mock http server - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - const testHeader = "X-Test-Header" - const testValue = "TestValue" - transport := Transport{ - headers: map[string]string{ - testHeader: testValue, - }, - } - req, err := http.NewRequest("GET", server.URL, nil) - require.NoError(t, err) - resp, err := transport.RoundTrip(req) - require.NoError(t, err) - require.Equal(t, resp.Request.Header.Get(testHeader), testValue) -} - -func TestTransportOnlyAddsHeaderIfMissing(t *testing.T) { - // Setup mock http server - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - const testHeader = "X-Test-Header" - const testValue = "TestValue" - transport := Transport{ - headers: map[string]string{ - testHeader: testValue, - }, - } - const expectedValue = "ExpectedValue" - req, err := http.NewRequest("GET", server.URL, nil) - req.Header.Set(testHeader, expectedValue) - require.NoError(t, err) - resp, err := transport.RoundTrip(req) - require.NoError(t, err) - require.Equal(t, resp.Request.Header.Get(testHeader), expectedValue) -} - -func TestTransportSendsErrorWithMissingToken(t *testing.T) { - // Setup mock http server - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - u, err := url.Parse(server.URL) - require.NoError(t, err) - - transport := Transport{ - authentication: map[string]string{ - u.Host: BearerHeaderPrefix + "", - }, - } - req, err := http.NewRequest("GET", server.URL, nil) - require.NoError(t, err) - resp, err := transport.RoundTrip(req) - require.Error(t, err) - require.Nil(t, resp) -} diff --git a/pkg/http/user_agent.go b/pkg/http/user_agent.go deleted file mode 100644 index 558aca28bb..0000000000 --- a/pkg/http/user_agent.go +++ /dev/null @@ -1,24 +0,0 @@ -package http - -import ( - "fmt" - "runtime" - - "github.com/replicate/cog/pkg/global" -) - -func UserAgent() string { - var platform string - switch runtime.GOOS { - case "linux": - platform = "Linux" - case "windows": - platform = "Windows" - case "darwin": - platform = "macOS" - default: - platform = runtime.GOOS - } - - return fmt.Sprintf("Cog/%s (%s)", global.Version, platform) -} diff --git a/pkg/http/user_agent_test.go b/pkg/http/user_agent_test.go deleted file mode 100644 index 07e91449e2..0000000000 --- a/pkg/http/user_agent_test.go +++ /dev/null @@ -1,12 +0,0 @@ -package http - -import ( - "strings" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestUserAgent(t *testing.T) { - require.True(t, strings.HasPrefix(UserAgent(), "Cog/")) -} diff --git a/pkg/web/client.go b/pkg/web/client.go deleted file mode 100644 index 2dfc6f61be..0000000000 --- a/pkg/web/client.go +++ /dev/null @@ -1,538 +0,0 @@ -package web - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "net/http" - "net/url" - "os" - "strconv" - "strings" - "sync" - "time" - - "github.com/docker/docker/api/types/image" - "github.com/replicate/go/types" - "golang.org/x/sync/errgroup" - - "github.com/replicate/cog/pkg/config" - "github.com/replicate/cog/pkg/docker/command" - "github.com/replicate/cog/pkg/env" - r8_errors "github.com/replicate/cog/pkg/errors" - "github.com/replicate/cog/pkg/global" - "github.com/replicate/cog/pkg/util" - "github.com/replicate/cog/pkg/util/console" -) - -const ( - pushStartURLPath = "/api/models/push-start" - startChallengeURLPath = "/api/models/file-challenge" -) - -var ( - ErrorBadResponseNewVersionEndpoint = errors.New("Bad response from new version endpoint") - ErrorBadResponsePushStartEndpoint = errors.New("Bad response from push start endpoint") - ErrorBadResponseInitiateChallengeEndpoint = errors.New("Bad response from start file challenge endpoint") - ErrorNoSuchDigest = errors.New("No digest submitted matches the digest requested") -) - -type Client struct { - dockerCommand command.Command - client *http.Client -} - -type File struct { - Path string `json:"path"` - Digest string `json:"digest"` - Size int64 `json:"size"` -} - -type Env struct { - CogGpu string `json:"COG_GPU"` - CogPredictTypeStub string `json:"COG_PREDICT_TYPE_STUB"` - CogTrainTypeStub string `json:"COG_TRAIN_TYPE_STUB"` - CogPredictCodeStrip string `json:"COG_PREDICT_CODE_STRIP"` - CogTrainCodeStrip string `json:"COG_TRAIN_CODE_STRIP"` - R8CogVersion string `json:"R8_COG_VERSION"` - R8CudaVersion string `json:"R8_CUDA_VERSION"` - R8CudnnVersion string `json:"R8_CUDNN_VERSION"` - R8PythonVersion string `json:"R8_PYTHON_VERSION"` - R8TorchVersion string `json:"R8_TORCH_VERSION"` -} - -type RuntimeConfig struct { - Weights []File `json:"weights"` - Files []File `json:"files"` - Env Env `json:"env"` -} - -type Version struct { - Annotations map[string]string `json:"annotations"` - CogConfig config.Config `json:"cog_config"` - CogVersion string `json:"cog_version"` - OpenAPISchema map[string]any `json:"openapi_schema"` - RuntimeConfig RuntimeConfig `json:"runtime_config"` - Virtual bool `json:"virtual"` - PushID string `json:"push_id"` - Challenges []FileChallengeAnswer `json:"file_challenges"` -} - -type FileChallengeRequest struct { - Digest string `json:"digest"` - FileType string `json:"file_type"` -} - -type FileChallenge struct { - Salt string `json:"salt"` - Start int `json:"byte_start"` - End int `json:"byte_end"` - Digest string `json:"digest"` - ID string `json:"challenge_id"` -} - -type FileChallengeAnswer struct { - Digest string `json:"digest"` - Hash string `json:"hash"` - ChallengeID string `json:"challenge_id"` -} - -type VersionError struct { - Detail string `json:"detail"` - Pointer string `json:"pointer"` -} - -type VersionErrors struct { - Detail string `json:"detail"` - Errors []VersionError `json:"errors"` - Status int `json:"status"` - Title string `json:"title"` -} - -type VersionCreate struct { - Version string `json:"version"` -} - -type CogKey struct { - Key string `json:"key"` - ExpiresAt string `json:"expires_at"` -} - -type Keys struct { - Cog CogKey `json:"cog"` -} - -type TokenData struct { - Keys Keys `json:"keys"` -} - -func NewClient(dockerCommand command.Command, client *http.Client) *Client { - return &Client{ - dockerCommand: dockerCommand, - client: client, - } -} - -func (c *Client) PostPushStart(ctx context.Context, pushID string, buildTime time.Duration) error { - jsonBody := map[string]any{ - "push_id": pushID, - "build_duration": types.Duration(buildTime).String(), - "push_start_time": time.Now().UTC(), - } - - jsonData, err := json.Marshal(jsonBody) - if err != nil { - return util.WrapError(err, "failed to marshal JSON for build start") - } - - url := webBaseURL() - url.Path = pushStartURLPath - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url.String(), bytes.NewReader(jsonData)) - if err != nil { - return err - } - - resp, err := c.client.Do(req) //nolint:gosec // G704: URL from configured endpoint - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return util.WrapError(ErrorBadResponsePushStartEndpoint, strconv.Itoa(resp.StatusCode)) - } - - return nil -} - -func (c *Client) PostNewVersion(ctx context.Context, image string, weights []File, files []File, fileChallenges []FileChallengeAnswer) error { - version, err := c.versionFromManifest(ctx, image, weights, files, fileChallenges) - if err != nil { - return util.WrapError(err, "failed to build new version from manifest") - } - - jsonData, err := json.Marshal(version) - if err != nil { - return util.WrapError(err, "failed to marshal JSON for new version") - } - - versionUrl, err := newVersionURL(image) - if err != nil { - return err - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, versionUrl.String(), bytes.NewReader(jsonData)) - if err != nil { - return err - } - - resp, err := c.client.Do(req) //nolint:gosec // G704: URL from configured endpoint - if err != nil { - return err - } - defer resp.Body.Close() - decoder := json.NewDecoder(resp.Body) - - if resp.StatusCode != http.StatusCreated { - if resp.StatusCode == http.StatusBadRequest { - var versionErrors VersionErrors - err = decoder.Decode(&versionErrors) - if err != nil { - return err - } - return errors.New(versionErrors.Detail) - } - return util.WrapError(ErrorBadResponseNewVersionEndpoint, strconv.Itoa(resp.StatusCode)) - } - - var versionCreate VersionCreate - err = decoder.Decode(&versionCreate) - if err != nil { - return err - } - console.Infof("New Version: %s", versionCreate.Version) - - return nil -} - -func (c *Client) FetchAPIToken(ctx context.Context, entity string) (string, error) { - tokenUrl := tokenURL(entity) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenUrl.String(), nil) - if err != nil { - return "", err - } - - tokenResp, err := c.client.Do(req) //nolint:gosec // G704: URL from configured endpoint - if err != nil { - return "", err - } - defer tokenResp.Body.Close() - - if tokenResp.StatusCode != http.StatusOK { - return "", fmt.Errorf("Bad response: %s attempting to exchange tokens", strconv.Itoa(tokenResp.StatusCode)) - } - - var tokenData TokenData - err = json.NewDecoder(tokenResp.Body).Decode(&tokenData) - if err != nil { - return "", err - } - - return tokenData.Keys.Cog.Key, nil -} - -func (c *Client) versionFromManifest(ctx context.Context, image string, weights []File, files []File, fileChallenges []FileChallengeAnswer) (*Version, error) { - manifest, err := c.dockerCommand.Inspect(ctx, image) - if err != nil { - return nil, util.WrapError(err, "failed to inspect docker image") - } - - cogConfig, err := readCogConfig(manifest) - if err != nil { - return nil, err - } - - var openAPISchema map[string]any - err = json.Unmarshal([]byte(manifest.Config.Labels[command.CogOpenAPISchemaLabelKey]), &openAPISchema) - if err != nil { - return nil, util.WrapError(err, "failed to get OpenAPI schema from docker image") - } - - predictCode, err := stripCodeFromStub(cogConfig, true) - if err != nil { - return nil, err - } - trainCode, err := stripCodeFromStub(cogConfig, false) - if err != nil { - return nil, err - } - - var cogGPU int - if cogConfig.Build.GPU { - cogGPU = 1 - } - - cogVersion := "" - torchVersion := "" - cudaVersion := "" - cudnnVersion := "" - pythonVersion := "" - for _, env := range manifest.Config.Env { - envName, envValue, found := strings.Cut(env, "=") - if !found { - continue - } - switch envName { - case command.R8CogVersionEnvVarName: - cogVersion = envValue - case command.R8TorchVersionEnvVarName: - torchVersion = envValue - case command.R8CudaVersionEnvVarName: - cudaVersion = envValue - case command.R8CudnnVersionEnvVarName: - cudnnVersion = envValue - case command.R8PythonVersionEnvVarName: - pythonVersion = envValue - } - } - - env := Env{ - CogGpu: strconv.Itoa(cogGPU), - CogPredictTypeStub: cogConfig.Predict, - CogTrainTypeStub: cogConfig.Train, - CogPredictCodeStrip: predictCode, - CogTrainCodeStrip: trainCode, - R8CogVersion: cogVersion, - R8CudaVersion: cudaVersion, - R8CudnnVersion: cudnnVersion, - R8PythonVersion: pythonVersion, - R8TorchVersion: torchVersion, - } - - prefixedFiles := make([]File, len(files)) - - for i, file := range files { - prefixedFiles[i] = File{ - Path: file.Path, - Digest: "sha256:" + file.Digest, - Size: file.Size, - } - } - - prefixedWeights := make([]File, len(weights)) - - for i, file := range weights { - prefixedWeights[i] = File{ - Path: file.Path, - Digest: "sha256:" + file.Digest, - Size: file.Size, - } - } - - // Digests should match whatever digest we are sending in as the - // runtime config digests - for i, challenge := range fileChallenges { - fileChallenges[i] = FileChallengeAnswer{ - Digest: fmt.Sprintf("sha256:%s", challenge.Digest), - Hash: challenge.Hash, - ChallengeID: challenge.ChallengeID, - } - } - - runtimeConfig := RuntimeConfig{ - Weights: prefixedWeights, - Files: prefixedFiles, - Env: env, - } - - version := Version{ - Annotations: manifest.Config.Labels, - CogConfig: *cogConfig, - CogVersion: manifest.Config.Labels[command.CogVersionLabelKey], - OpenAPISchema: openAPISchema, - RuntimeConfig: runtimeConfig, - Virtual: true, - Challenges: fileChallenges, - } - - if pushID, ok := manifest.Config.Labels["run.cog.push_id"]; ok { - version.PushID = pushID - } - - return &version, nil -} - -func (c *Client) InitiateAndDoFileChallenge(ctx context.Context, weights []File, files []File) ([]FileChallengeAnswer, error) { - var challengeAnswers []FileChallengeAnswer - var mu sync.Mutex - - var wg errgroup.Group - for _, item := range files { - wg.Go(func() error { - answer, err := c.doSingleFileChallenge(ctx, item, "files") - if err != nil { - return util.WrapError(err, fmt.Sprintf("do file challenge for digest %s", item.Digest)) - } - mu.Lock() - challengeAnswers = append(challengeAnswers, answer) - mu.Unlock() - return nil - }) - } - for _, item := range weights { - wg.Go(func() error { - answer, err := c.doSingleFileChallenge(ctx, item, "weights") - if err != nil { - return util.WrapError(err, fmt.Sprintf("do file challenge for digest %s", item.Digest)) - } - mu.Lock() - challengeAnswers = append(challengeAnswers, answer) - mu.Unlock() - return nil - }) - } - if err := wg.Wait(); err != nil { - return nil, util.WrapError(err, "do file challenges") - } - - return challengeAnswers, nil -} - -// doSingleFileChallenge does a single file challenge. This is expected to be called in a goroutine. -func (c *Client) doSingleFileChallenge(ctx context.Context, file File, fileType string) (FileChallengeAnswer, error) { - initiateChallengePath := webBaseURL() - initiateChallengePath.Path = startChallengeURLPath - - answer := FileChallengeAnswer{} - - jsonData, err := json.Marshal(FileChallengeRequest{ - Digest: file.Digest, - FileType: fileType, - }) - - if err != nil { - return answer, util.WrapError(err, "encode request JSON") - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, initiateChallengePath.String(), bytes.NewReader(jsonData)) - if err != nil { - return answer, util.WrapError(err, "build HTTP request") - } - resp, err := c.client.Do(req) //nolint:gosec // G704: URL from configured endpoint - if err != nil { - return answer, util.WrapError(err, "do HTTP request") - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return answer, util.WrapError(ErrorBadResponseInitiateChallengeEndpoint, strconv.Itoa(resp.StatusCode)) - } - - var challenge FileChallenge - err = json.NewDecoder(resp.Body).Decode(&challenge) - if err != nil { - return answer, util.WrapError(err, "decode response body") - } - - ans, err := util.SHA256HashFileWithSaltAndRange(file.Path, challenge.Start, challenge.End, challenge.Salt) - if err != nil { - return answer, util.WrapError(err, "hash file") - } - return FileChallengeAnswer{ - Digest: file.Digest, - Hash: ans, - ChallengeID: challenge.ID, - }, nil -} - -func newVersionURL(image string) (url.URL, error) { - imageComponents := strings.Split(image, "/") - newVersionUrl := webBaseURL() - if len(imageComponents) != 3 || imageComponents[0] != global.ReplicateRegistryHost { - return newVersionUrl, r8_errors.ErrorBadRegistryURL - } - newVersionUrl.Path = strings.Join([]string{"", "api", "models", imageComponents[1], imageComponents[2], "versions"}, "/") - return newVersionUrl, nil -} - -func tokenURL(entity string) url.URL { - newVersionUrl := webBaseURL() - newVersionUrl.Path = strings.Join([]string{"", "api", "token", entity}, "/") - return newVersionUrl -} - -func webBaseURL() url.URL { - return url.URL{ - Scheme: env.SchemeFromEnvironment(), - Host: env.WebHostFromEnvironment(), - } -} - -func codeFileName(cogConfig *config.Config, isPredict bool) (string, error) { - var stubComponents []string - if isPredict { - if cogConfig.Predict == "" { - return "", nil - } - stubComponents = strings.Split(cogConfig.Predict, ":") - } else { - if cogConfig.Train == "" { - return "", nil - } - stubComponents = strings.Split(cogConfig.Train, ":") - } - - if len(stubComponents) < 2 { - return "", errors.New("Code stub components has less than 2 entries.") - } - - return stubComponents[0], nil -} - -func readCode(cogConfig *config.Config, isPredict bool) (string, string, error) { - codeFile, err := codeFileName(cogConfig, isPredict) - if err != nil { - return "", codeFile, err - } - if codeFile == "" { - return "", "", nil - } - - b, err := os.ReadFile(codeFile) - if err != nil { - return "", codeFile, err - } - - return string(b), codeFile, nil -} - -func stripCodeFromStub(cogConfig *config.Config, isPredict bool) (string, error) { - // TODO: We should attempt to strip the code here, in python this is done like so: - // from cog.code_xforms import strip_model_source_code - // code = strip_model_source_code( - // util.read_file(os.path.join(fs, 'src', base_file)), - // [base_class], - // ['predict', 'train'], - // ) - // Currently the behavior of the code strip attempts to strip, and if it can't it - // loads the whole file in. Here we just load the whole file in. - // We should figure out a way to call cog python from here to fulfill this. - // It could be a good idea to do this in the layer functions where we do pip freeze - // et al. - - code, _, err := readCode(cogConfig, isPredict) - return code, err -} - -func readCogConfig(manifest *image.InspectResponse) (*config.Config, error) { - var cogConfig config.Config - err := json.Unmarshal([]byte(manifest.Config.Labels[command.CogConfigLabelKey]), &cogConfig) - if err != nil { - return nil, util.WrapError(err, "failed to get cog config from docker image") - } - - return &cogConfig, nil -} diff --git a/pkg/web/client_test.go b/pkg/web/client_test.go deleted file mode 100644 index 1c77817d59..0000000000 --- a/pkg/web/client_test.go +++ /dev/null @@ -1,204 +0,0 @@ -package web - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "net/url" - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/replicate/cog/pkg/config" - "github.com/replicate/cog/pkg/docker/dockertest" - "github.com/replicate/cog/pkg/env" -) - -func TestPostNewVersion(t *testing.T) { - // Setup mock http server - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - output := "{\"version\":\"user/test:53c740f17ce88a61c3da5b0c20e48fd48e2da537c3a1276dec63ab11fbad6bcb\"}" - w.WriteHeader(http.StatusCreated) - w.Write([]byte(output)) - })) - defer server.Close() - url, err := url.Parse(server.URL) - require.NoError(t, err) - t.Setenv(env.SchemeEnvVarName, url.Scheme) - t.Setenv(env.WebHostEnvVarName, url.Host) - - dir := t.TempDir() - - // Create mock predict - predictPyPath := filepath.Join(dir, "predict.py") - handle, err := os.Create(predictPyPath) - require.NoError(t, err) - handle.WriteString("import cog") - dockertest.MockCogConfig = "{\"build\":{\"python_version\":\"3.12\",\"python_packages\":[\"torch==2.5.0\",\"beautifulsoup4==4.12.3\"],\"system_packages\":[\"git\"]},\"image\":\"test\",\"predict\":\"" + predictPyPath + ":Predictor\"}" - - // Setup mock command - command := dockertest.NewMockCommand() - - client := NewClient(command, http.DefaultClient) - err = client.PostNewVersion(t.Context(), "r8.im/user/test", []File{}, []File{}, nil) - require.NoError(t, err) -} - -func TestVersionFromManifest(t *testing.T) { - // Setup mock command - command := dockertest.NewMockCommand() - - // Create mock predict - dir := t.TempDir() - predictPyPath := filepath.Join(dir, "predict.py") - handle, err := os.Create(predictPyPath) - require.NoError(t, err) - handle.WriteString("import cog") - dockertest.MockCogConfig = "{\"build\":{\"python_version\":\"3.12\",\"python_packages\":[\"torch==2.5.0\",\"beautifulsoup4==4.12.3\"],\"system_packages\":[\"git\"]},\"image\":\"test\",\"predict\":\"" + predictPyPath + ":Predictor\"}" - dockertest.MockOpenAPISchema = "{\"test\": true}" - - client := NewClient(command, http.DefaultClient) - version, err := client.versionFromManifest(t.Context(), "r8.im/user/test", []File{}, []File{}, nil) - require.NoError(t, err) - - var openAPISchema map[string]any - err = json.Unmarshal([]byte(dockertest.MockOpenAPISchema), &openAPISchema) - require.NoError(t, err) - - var cogConfig config.Config - err = json.Unmarshal([]byte(dockertest.MockCogConfig), &cogConfig) - require.NoError(t, err) - - require.Equal(t, openAPISchema, version.OpenAPISchema) - require.Equal(t, cogConfig, version.CogConfig) -} - -func TestVersionURLErrorWithoutR8IMPrefix(t *testing.T) { - _, err := newVersionURL("docker.com/thing/thing") - require.Error(t, err) -} - -func TestVersionURLErrorWithout3Components(t *testing.T) { - _, err := newVersionURL("username/test") - require.Error(t, err) -} - -func TestDoFileChallenge(t *testing.T) { - dir := t.TempDir() - path := filepath.Join(dir, "test.tmp") - d1 := []byte("hello\nreplicate\nhello\n") - err := os.WriteFile(path, d1, 0o644) - require.NoError(t, err) - - path2 := filepath.Join(dir, "test2.tmp") - d2 := []byte("hello\nreplicate\nhello\n") - err = os.WriteFile(path2, d2, 0o644) - require.NoError(t, err) - - files := []File{ - { - Path: path, - Digest: "abc", - Size: 22, - }, - } - weights := []File{ - { - Path: path, - Digest: "def", - Size: 22, - }, - } - - abcChallenge := FileChallenge{ - ID: "abc", - Digest: "abc", - Start: 0, - End: 6, - Salt: "go\n", - } - - defChallenge := FileChallenge{ - ID: "def", - Digest: "def", - Start: 16, - End: 22, - Salt: "go\n", - } - - // Setup mock http server - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - var challengeRequest FileChallengeRequest - // Ignore errors - make sure the test is set up correctly - json.NewDecoder(r.Body).Decode(&challengeRequest) - if challengeRequest.Digest == "abc" { - body, _ := json.Marshal(abcChallenge) - w.Write(body) - } else { - body, _ := json.Marshal(defChallenge) - w.Write(body) - } - })) - defer server.Close() - url, err := url.Parse(server.URL) - require.NoError(t, err) - t.Setenv(env.SchemeEnvVarName, url.Scheme) - t.Setenv(env.WebHostEnvVarName, url.Host) - - // Setup mock command - command := dockertest.NewMockCommand() - client := NewClient(command, http.DefaultClient) - response, err := client.InitiateAndDoFileChallenge(t.Context(), weights, files) - require.NoError(t, err) - assert.ElementsMatch(t, response, []FileChallengeAnswer{ - { - ChallengeID: "abc", - Digest: "abc", - Hash: "43d250d92b5dbb47f75208de8e9a9a321d23e85eed0dc3d5dfa83bc3cc5aa68c", - }, - { - ChallengeID: "def", - Digest: "def", - Hash: "43d250d92b5dbb47f75208de8e9a9a321d23e85eed0dc3d5dfa83bc3cc5aa68c", - }, - }) -} - -func TestFetchToken(t *testing.T) { - // Setup mock http server - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/api/token/user": - // Mock token exchange response - tokenResponse := `{ - "keys": { - "cog": { - "key": "test-api-token", - "expires_at": "2024-12-31T23:59:59Z" - } - } - }` - w.WriteHeader(http.StatusOK) - w.Write([]byte(tokenResponse)) - default: - w.WriteHeader(http.StatusNotFound) - } - })) - defer server.Close() - url, err := url.Parse(server.URL) - require.NoError(t, err) - t.Setenv(env.SchemeEnvVarName, url.Scheme) - t.Setenv(env.WebHostEnvVarName, url.Host) - - // Setup mock command - command := dockertest.NewMockCommand() - - client := NewClient(command, http.DefaultClient) - token, err := client.FetchAPIToken(t.Context(), "user") - require.NoError(t, err) - require.Equal(t, "test-api-token", token) -} From 20ed5c232197b5463de36c2194b28f6490475f66 Mon Sep 17 00:00:00 2001 From: Michael Dwan Date: Thu, 30 Apr 2026 09:41:49 -0600 Subject: [PATCH 2/3] remove dead code: fast-build leftovers, old docker CLI client, config helpers, console/shell utils, predict inputs --- go.mod | 1 - go.sum | 2 - pkg/config/build_options.go | 8 -- pkg/config/config.go | 9 --- pkg/config/config_test.go | 23 +++--- pkg/config/parse.go | 39 --------- pkg/config/validate.go | 14 ---- pkg/docker/docker_client_test.go | 2 - pkg/docker/env.go | 13 --- pkg/docker/errors.go | 32 -------- pkg/dockerfile/version_check.go | 11 --- pkg/dockerignore/dockerignore.go | 26 ------ pkg/dockerignore/dockerignore_test.go | 56 ------------- pkg/predict/input.go | 17 ---- pkg/requirements/requirements.go | 44 ---------- pkg/requirements/requirements_test.go | 18 ----- pkg/util/console/console.go | 11 --- pkg/util/console/formatting.go | 11 --- pkg/util/console/global.go | 10 --- pkg/util/console/interactive.go | 111 -------------------------- pkg/util/console/levels.go | 55 +------------ pkg/util/console/term.go | 15 ---- pkg/util/errors.go | 12 --- pkg/util/files/files.go | 55 ------------- pkg/util/files/files_test.go | 12 --- pkg/util/hash.go | 76 ------------------ pkg/util/hash_test.go | 43 ---------- pkg/util/ringbuffer.go | 60 -------------- pkg/util/shell/net.go | 57 ------------- pkg/util/shell/pipes.go | 28 ------- 30 files changed, 10 insertions(+), 861 deletions(-) delete mode 100644 pkg/docker/env.go delete mode 100644 pkg/dockerignore/dockerignore_test.go delete mode 100644 pkg/util/console/formatting.go delete mode 100644 pkg/util/console/interactive.go delete mode 100644 pkg/util/errors.go delete mode 100644 pkg/util/hash.go delete mode 100644 pkg/util/hash_test.go delete mode 100644 pkg/util/ringbuffer.go delete mode 100644 pkg/util/shell/net.go delete mode 100644 pkg/util/shell/pipes.go diff --git a/go.mod b/go.mod index 88fcdcc975..fbb116234e 100644 --- a/go.mod +++ b/go.mod @@ -31,7 +31,6 @@ require ( github.com/tonistiigi/go-csvvalue v0.0.0-20240814133006-030d3b2625d0 github.com/vincent-petithory/dataurl v1.0.0 github.com/xeipuuv/gojsonschema v1.2.0 - github.com/xeonx/timeago v1.0.0-rc5 go.yaml.in/yaml/v4 v4.0.0-rc.4 golang.org/x/crypto v0.50.0 golang.org/x/exp v0.0.0-20250911091902-df9299821621 diff --git a/go.sum b/go.sum index 40e1c7bf50..47924e59b6 100644 --- a/go.sum +++ b/go.sum @@ -292,8 +292,6 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHo github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= -github.com/xeonx/timeago v1.0.0-rc5 h1:pwcQGpaH3eLfPtXeyPA4DmHWjoQt0Ea7/++FwpxqLxg= -github.com/xeonx/timeago v1.0.0-rc5/go.mod h1:qDLrYEFynLO7y5Ho7w3GwgtYgpy5UfhcXIIQvMKVDkA= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= diff --git a/pkg/config/build_options.go b/pkg/config/build_options.go index 21e63cc2a5..7cc0309b8f 100644 --- a/pkg/config/build_options.go +++ b/pkg/config/build_options.go @@ -14,11 +14,3 @@ type BuildOptions struct { // If empty, inline caching is used instead of local cache. XCachePath string } - -// DefaultBuildOptions returns BuildOptions with sensible defaults. -func DefaultBuildOptions() BuildOptions { - return BuildOptions{ - SourceEpochTimestamp: -1, - XCachePath: "", - } -} diff --git a/pkg/config/config.go b/pkg/config/config.go index ba1fe8b681..3846288c9c 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -85,15 +85,6 @@ type Config struct { parsedEnvironment map[string]string } -func defaultConfig() *Config { - return &Config{ - Build: &Build{ - GPU: false, - PythonVersion: "3.13", - }, - } -} - func (r *RunItem) UnmarshalYAML(unmarshal func(any) error) error { var commandOrMap any if err := unmarshal(&commandOrMap); err != nil { diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index ba99b89d0a..55f7bb5e00 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -589,19 +589,14 @@ torch==1.12.1` func TestBlankBuild(t *testing.T) { // Naively, this turns into nil, so make sure it's a real build object - // Write a temp file - dir := t.TempDir() - configPath := path.Join(dir, "cog.yaml") - err := os.WriteFile(configPath, []byte(`build:`), 0o644) - require.NoError(t, err) - - cfgFile, err := parseFile(configPath) + cfgFile, err := parseBytes([]byte(`build:`)) require.NoError(t, err) // Note: `build:` by itself in YAML parses to Build: nil (empty map becomes nil pointer) // The completion step should create a default Build config, err := configFileToConfig(cfgFile) require.NoError(t, err) + dir := t.TempDir() require.NoError(t, config.Complete(dir)) require.NotNil(t, config.Build) require.Equal(t, false, config.Build.GPU) @@ -637,17 +632,17 @@ build: run: - command: "echo 'Hello, World!'" ` - dir := t.TempDir() - configPath := path.Join(dir, "cog.yaml") - err := os.WriteFile(configPath, []byte(yamlString), 0o644) - require.NoError(t, err) - - _, err = parseFile(configPath) + _, err := parseBytes([]byte(yamlString)) require.NoError(t, err) } func TestConfigMarshal(t *testing.T) { - cfg := defaultConfig() + cfg := &Config{ + Build: &Build{ + GPU: false, + PythonVersion: "3.13", + }, + } data, err := yaml.Marshal(cfg) require.NoError(t, err) // yaml v4 uses 4-space indentation by default diff --git a/pkg/config/parse.go b/pkg/config/parse.go index 1826b56f15..c8700183ee 100644 --- a/pkg/config/parse.go +++ b/pkg/config/parse.go @@ -3,12 +3,8 @@ package config import ( "fmt" "io" - "os" - "path/filepath" "go.yaml.in/yaml/v4" - - "github.com/replicate/cog/pkg/util/files" ) // parse reads and parses YAML content from an io.Reader into a configFile. @@ -23,41 +19,6 @@ func parse(r io.Reader) (*configFile, error) { return parseBytes(contents) } -// parseFile reads and parses a cog.yaml file into a configFile. -// This only does YAML parsing - no validation or defaults. -// Returns ParseError if the file cannot be read or parsed. -func parseFile(filename string) (*configFile, error) { - exists, err := files.Exists(filename) - if err != nil { - return nil, &ParseError{Filename: filename, Err: err} - } - - if !exists { - return nil, &ParseError{ - Filename: filename, - Err: fmt.Errorf("%s does not exist in %s", filepath.Base(filename), filepath.Dir(filename)), - } - } - - f, err := os.Open(filename) - if err != nil { - return nil, &ParseError{Filename: filename, Err: err} - } - defer f.Close() - - cfg, err := parse(f) - if err != nil { - // Add filename context to the error - if parseErr, ok := err.(*ParseError); ok { - parseErr.Filename = filename - return nil, parseErr - } - return nil, &ParseError{Filename: filename, Err: err} - } - - return cfg, nil -} - // parseBytes parses YAML content into a configFile. func parseBytes(contents []byte) (*configFile, error) { cfg := &configFile{} diff --git a/pkg/config/validate.go b/pkg/config/validate.go index e5ab37ebcc..15ebe8a837 100644 --- a/pkg/config/validate.go +++ b/pkg/config/validate.go @@ -36,20 +36,6 @@ func WithProjectDir(dir string) ValidateOption { } } -// WithRequirementsFS sets the filesystem for reading python_requirements file. -func WithRequirementsFS(fsys fs.FS) ValidateOption { - return func(o *validateOptions) { - o.requirementsFS = fsys - } -} - -// WithStrictDeprecations treats deprecation warnings as errors. -func WithStrictDeprecations() ValidateOption { - return func(o *validateOptions) { - o.strictDeprecations = true - } -} - // ValidateConfigFile checks a configFile for errors. // Returns all validation errors and deprecation warnings. // Does not mutate the input. diff --git a/pkg/docker/docker_client_test.go b/pkg/docker/docker_client_test.go index 49917ae0fd..0432b5d4eb 100644 --- a/pkg/docker/docker_client_test.go +++ b/pkg/docker/docker_client_test.go @@ -325,8 +325,6 @@ func TestDockerClient(t *testing.T) { // Try to push to the mock registry err = dockerClient.Push(t.Context(), ref.String()) require.Error(t, err, "Push should fail with unreachable registry") - - assert.True(t, isNetworkError(err), "Error should be a network error, got: %q", err.Error()) }) t.Run("missing image", func(t *testing.T) { diff --git a/pkg/docker/env.go b/pkg/docker/env.go deleted file mode 100644 index 6dc5f4d802..0000000000 --- a/pkg/docker/env.go +++ /dev/null @@ -1,13 +0,0 @@ -package docker - -import "os" - -const DockerCommandEnvVarName = "R8_DOCKER_COMMAND" - -func DockerCommandFromEnvironment() string { - command := os.Getenv(DockerCommandEnvVarName) - if command == "" { - command = "docker" - } - return command -} diff --git a/pkg/docker/errors.go b/pkg/docker/errors.go index 4ebc97be13..42f40e4919 100644 --- a/pkg/docker/errors.go +++ b/pkg/docker/errors.go @@ -1,7 +1,6 @@ package docker import ( - "errors" "strings" ) @@ -48,34 +47,3 @@ func isMissingDeviceDriverError(err error) bool { strings.Contains(msg, "nvidia-container-cli: initialization error") } -// isNetworkError checks if the error is a network error. This is janky and intended for use in tests only -func isNetworkError(err error) bool { - // for both CLI and API clients, network errors are wrapped and lose the net.Error interface - // CLI client: wrapped by exec.Command as exec.ExitError - // API client: wrapped by JSON message stream processing - // Sad as it may be, we rely on string matching for common network error messages - - msg := err.Error() - networkErrorStrings := []string{ - "connection refused", - "connection reset by peer", - "dial tcp", - "EOF", - "no route to host", - "network is unreachable", - "server closed", - } - - for _, errStr := range networkErrorStrings { - if strings.Contains(msg, errStr) { - return true - } - } - - // also check wrapped errors - if unwrapped := errors.Unwrap(err); unwrapped != nil { - return isNetworkError(unwrapped) - } - - return false -} diff --git a/pkg/dockerfile/version_check.go b/pkg/dockerfile/version_check.go index 21bd322918..bdf1c69c60 100644 --- a/pkg/dockerfile/version_check.go +++ b/pkg/dockerfile/version_check.go @@ -17,20 +17,9 @@ func parse(s string) (string, string, string) { minor := m[versionRegex.SubexpIndex("minor")] patch := m[versionRegex.SubexpIndex("patch")] return major, minor, patch - -} - -func CheckMajorOnly(s string) bool { - major, minor, patch := parse(s) - return major != "" && minor == "" && patch == "" } func CheckMajorMinorOnly(s string) bool { major, minor, patch := parse(s) return major != "" && minor != "" && patch == "" } - -func CheckMajorMinorPatch(s string) bool { - major, minor, patch := parse(s) - return major != "" && minor != "" && patch != "" -} diff --git a/pkg/dockerignore/dockerignore.go b/pkg/dockerignore/dockerignore.go index 987f199c03..4d32b29039 100644 --- a/pkg/dockerignore/dockerignore.go +++ b/pkg/dockerignore/dockerignore.go @@ -29,32 +29,6 @@ func CreateMatcher(dir string) (*ignore.GitIgnore, error) { return ignore.CompileIgnoreLines(patterns...), nil } -func Walk(root string, ignoreMatcher *ignore.GitIgnore, fn filepath.WalkFunc) error { - return filepath.Walk(root, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - - // We ignore files ignored by .dockerignore - if ignoreMatcher != nil && ignoreMatcher.MatchesPath(path) { - if info.IsDir() { - return filepath.SkipDir - } - return nil - } - - if info.IsDir() && info.Name() == ".cog" { - return filepath.SkipDir - } - - if info.Name() == DockerIgnoreFilename { - return nil - } - - return fn(path, info, err) - }) -} - func readDockerIgnore(dockerIgnorePath string) ([]string, error) { var patterns []string file, err := os.Open(dockerIgnorePath) diff --git a/pkg/dockerignore/dockerignore_test.go b/pkg/dockerignore/dockerignore_test.go deleted file mode 100644 index 64ef1b9578..0000000000 --- a/pkg/dockerignore/dockerignore_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package dockerignore - -import ( - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestWalk(t *testing.T) { - dir := t.TempDir() - - predictOtherPyFilename := "predict_other.py" - predictOtherPyFilepath := filepath.Join(dir, predictOtherPyFilename) - predictOtherPyHandle, err := os.Create(predictOtherPyFilepath) - require.NoError(t, err) - predictOtherPyHandle.WriteString("import cog") - - dockerIgnorePath := filepath.Join(dir, ".dockerignore") - dockerIgnoreHandle, err := os.Create(dockerIgnorePath) - require.NoError(t, err) - dockerIgnoreHandle.WriteString(predictOtherPyFilename) - - predictPyFilename := "predict.py" - predictPyFilepath := filepath.Join(dir, predictPyFilename) - predictPyHandle, err := os.Create(predictPyFilepath) - require.NoError(t, err) - predictPyHandle.WriteString("import cog") - - matcher, err := CreateMatcher(dir) - require.NoError(t, err) - - foundFiles := []string{} - err = Walk(dir, matcher, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - - if info.IsDir() { - return nil - } - - relPath, err := filepath.Rel(dir, path) - if err != nil { - return err - } - - foundFiles = append(foundFiles, relPath) - - return nil - }) - require.NoError(t, err) - - require.Equal(t, []string{predictPyFilename}, foundFiles) -} diff --git a/pkg/predict/input.go b/pkg/predict/input.go index b4348f13b9..0e17993f31 100644 --- a/pkg/predict/input.go +++ b/pkg/predict/input.go @@ -27,10 +27,6 @@ type Input struct { type Inputs map[string]Input -func NewInputs(keyVals map[string][]string, schema *openapi3.T) (Inputs, error) { - return NewInputsForMode(keyVals, schema, false) -} - func NewInputsForMode(keyVals map[string][]string, schema *openapi3.T, isTrain bool) (Inputs, error) { schemaKey := "Input" if isTrain { @@ -133,19 +129,6 @@ func NewInputsForMode(keyVals map[string][]string, schema *openapi3.T, isTrain b return input, nil } -func NewInputsWithBaseDir(keyVals map[string]string, baseDir string) Inputs { - input := Inputs{} - for key, val := range keyVals { - if strings.HasPrefix(val, "@") { - val = filepath.Join(baseDir, val[1:]) - input[key] = Input{File: &val} - } else { - input[key] = Input{String: &val} - } - } - return input -} - func (inputs *Inputs) toMap() (map[string]any, error) { keyVals := map[string]any{} for key, input := range *inputs { diff --git a/pkg/requirements/requirements.go b/pkg/requirements/requirements.go index f679e8b68a..b3b8ac59ab 100644 --- a/pkg/requirements/requirements.go +++ b/pkg/requirements/requirements.go @@ -2,46 +2,12 @@ package requirements import ( "bufio" - "errors" "fmt" "os" - "path/filepath" "regexp" "strings" - - "github.com/replicate/cog/pkg/util/files" ) -const RequirementsFile = "requirements.txt" -const OverridesFile = "overrides.txt" - -func GenerateRequirements(tmpDir string, path string, fileName string) (string, error) { - bs, err := os.ReadFile(path) - if err != nil { - return "", err - } - requirements := string(bs) - - // Check against the old requirements - requirementsFile := filepath.Join(tmpDir, fileName) - if err := files.WriteIfDifferent(requirementsFile, requirements); err != nil { - return "", err - } - return requirementsFile, err -} - -func CurrentRequirements(tmpDir string) (string, error) { - requirementsFile := filepath.Join(tmpDir, RequirementsFile) - _, err := os.Stat(requirementsFile) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - return "", nil - } - return "", err - } - return requirementsFile, nil -} - func ReadRequirements(path string) ([]string, error) { re := regexp.MustCompile(`(?m)^\s*-e\s+\.\s*$`) @@ -185,16 +151,6 @@ func PackageName(pipRequirement string) string { return "" } -func VersionSpecifier(pipRequirement string) string { - re := regexp.MustCompile(`^[a-zA-Z0-9_\-\.]+(?:\[[^\]]+\])?\s*([<>=!~]=?\s*[^;,#\s]+(?:\s*,\s*[<>=!~]=?\s*[^;,#\s]+)*(?:\s*\|\|\s*[<>=!~]=?\s*[^;,#\s]+(?:\s*,\s*[<>=!~]=?\s*[^;,#\s]+)*)*)?`) - match := re.FindStringSubmatch(pipRequirement) - if len(match) > 1 { - // Optional: strip spaces for uniform output - return strings.ReplaceAll(match[1], " ", "") - } - return "" -} - func Versions(pipRequirement string) []string { var versions []string diff --git a/pkg/requirements/requirements_test.go b/pkg/requirements/requirements_test.go index ec85cf92e1..dabbfe9de3 100644 --- a/pkg/requirements/requirements_test.go +++ b/pkg/requirements/requirements_test.go @@ -3,25 +3,12 @@ package requirements import ( "os" "path" - "path/filepath" "strings" "testing" "github.com/stretchr/testify/require" ) -func TestPythonRequirements(t *testing.T) { - srcDir := t.TempDir() - reqFile := path.Join(srcDir, "requirements.txt") - err := os.WriteFile(reqFile, []byte("torch==2.5.1"), 0o644) - require.NoError(t, err) - - tmpDir := t.TempDir() - requirementsFile, err := GenerateRequirements(tmpDir, reqFile, RequirementsFile) - require.NoError(t, err) - require.Equal(t, filepath.Join(tmpDir, "requirements.txt"), requirementsFile) -} - func TestReadRequirements(t *testing.T) { srcDir := t.TempDir() reqFile := path.Join(srcDir, "requirements.txt") @@ -432,11 +419,6 @@ func TestReadRequirementsWithEditable(t *testing.T) { require.Equal(t, []string{"torch==2.5.1"}, requirements) } -func TestVersionSpecifier(t *testing.T) { - specifier := VersionSpecifier("mypackage>= 1.0, < 1.4 || > 2.0") - require.Equal(t, specifier, ">=1.0,<1.4||>2.0") -} - func TestPackageName(t *testing.T) { name := PackageName("mypackage>= 1.0, < 1.4 || > 2.0") require.Equal(t, name, "mypackage") diff --git a/pkg/util/console/console.go b/pkg/util/console/console.go index 7f7a2777a5..0738bdf54b 100644 --- a/pkg/util/console/console.go +++ b/pkg/util/console/console.go @@ -72,22 +72,11 @@ func (c *Console) Info(msg string) { c.log(InfoLevel, msg) } -// Success tells the user something completed successfully. -// Displays at info level with a green ✓ prefix. -func (c *Console) Success(msg string) { - c.logStyled(InfoLevel, StyleSuccess, msg) -} - // Warn tells the user that something might break. func (c *Console) Warn(msg string) { c.log(WarnLevel, msg) } -// Error tells the user that something is broken. -func (c *Console) Error(msg string) { - c.log(ErrorLevel, msg) -} - // Fatal level message, followed by exit func (c *Console) Fatal(msg string) { c.log(FatalLevel, msg) diff --git a/pkg/util/console/formatting.go b/pkg/util/console/formatting.go deleted file mode 100644 index 3184b7a280..0000000000 --- a/pkg/util/console/formatting.go +++ /dev/null @@ -1,11 +0,0 @@ -package console - -import ( - "time" - - "github.com/xeonx/timeago" -) - -func FormatTime(t time.Time) string { - return timeago.English.Format(t) -} diff --git a/pkg/util/console/global.go b/pkg/util/console/global.go index d6357def1e..5256e02052 100644 --- a/pkg/util/console/global.go +++ b/pkg/util/console/global.go @@ -33,21 +33,11 @@ func Info(msg string) { ConsoleInstance.Info(msg) } -// Success level message. -func Success(msg string) { - ConsoleInstance.Success(msg) -} - // Warn level message. func Warn(msg string) { ConsoleInstance.Warn(msg) } -// Error level message. -func Error(msg string) { - ConsoleInstance.Error(msg) -} - // Fatal level message. func Fatal(msg string) { ConsoleInstance.Fatal(msg) diff --git a/pkg/util/console/interactive.go b/pkg/util/console/interactive.go deleted file mode 100644 index 8257f7d1e5..0000000000 --- a/pkg/util/console/interactive.go +++ /dev/null @@ -1,111 +0,0 @@ -package console - -import ( - "bufio" - "fmt" - "io" - "os" - "slices" - "strings" -) - -type Interactive struct { - Prompt string - Default string - Options []string - Required bool -} - -func (i Interactive) Read() (string, error) { - if i.Default != "" && i.Options != nil && !slices.Contains(i.Options, i.Default) { - panic("Default is not an option") - } - - parens := "" - if i.Required { - parens += "required" - } - if i.Default != "" { - if parens != "" { - parens += ", " - } - parens += "default: " + i.Default - } - if i.Options != nil { - if parens != "" { - parens += ", " - } - parens += "options: " + strings.Join(i.Options, ", ") - } - if parens != "" { - parens = " (" + parens + ")" - } - - for { - fmt.Printf("%s%s: ", i.Prompt, parens) - reader := bufio.NewReader(os.Stdin) - text, err := reader.ReadString('\n') - if err != nil { - return "", err - } - text = strings.TrimSpace(text) - if text == "" && i.Default != "" { - text = i.Default - } - - if i.Required && text == "" { - Warn("Please enter a value") - continue - } - - if !i.Required && text == "" { - return "", nil - } - - if i.Options != nil { - if !slices.Contains(i.Options, text) { - Warnf("%s is not a valid option", text) - continue - } - } - - return text, nil - } -} - -type InteractiveBool struct { - Prompt string - Default bool - // NonDefaultFlag is the flag to suggest passing to do the thing which isn't default when running inside a script - NonDefaultFlag string -} - -func (i InteractiveBool) Read() (bool, error) { - defaults := "y/N" - if i.Default { - defaults = "Y/n" - } - for { - fmt.Printf("%s (%s) ", i.Prompt, defaults) - reader := bufio.NewReader(os.Stdin) - text, err := reader.ReadString('\n') - if err != nil { - // Only translate error if a flag is set - if err == io.EOF && i.NonDefaultFlag != "" { - return false, fmt.Errorf("stdin is closed. If you're running in a script, you need to pass the '%s' option", i.NonDefaultFlag) - } - return false, err - } - text = strings.ToLower(strings.TrimSpace(text)) - if text == "yes" || text == "y" { - return true, nil - } - if text == "no" || text == "n" { - return false, nil - } - if text == "" { - return i.Default, nil - } - Warn("Please enter 'y' or 'n'") - } -} diff --git a/pkg/util/console/levels.go b/pkg/util/console/levels.go index cd3512a131..001dbe3388 100644 --- a/pkg/util/console/levels.go +++ b/pkg/util/console/levels.go @@ -1,66 +1,13 @@ package console -// Mostly lifted from https://github.com/apex/log/blob/master/levels.go - -import ( - "errors" - "strings" -) - -// ErrInvalidLevel is returned if the severity level is invalid. -var ErrInvalidLevel = errors.New("invalid level") - // Level of severity. type Level int // Log levels. const ( - InvalidLevel Level = iota - 1 - DebugLevel + DebugLevel Level = iota InfoLevel WarnLevel ErrorLevel FatalLevel ) - -var levelNames = [...]string{ - DebugLevel: "debug", - InfoLevel: "info", - WarnLevel: "warn", - ErrorLevel: "error", - FatalLevel: "fatal", -} - -var levelStrings = map[string]Level{ - "debug": DebugLevel, - "info": InfoLevel, - "warn": WarnLevel, - "warning": WarnLevel, - "error": ErrorLevel, - "fatal": FatalLevel, -} - -// String implementation. -func (l Level) String() string { - return levelNames[l] -} - -// ParseLevel parses level string. -func ParseLevel(s string) (Level, error) { - l, ok := levelStrings[strings.ToLower(s)] - if !ok { - return InvalidLevel, ErrInvalidLevel - } - - return l, nil -} - -// MustParseLevel parses level string or panics. -func MustParseLevel(s string) Level { - l, err := ParseLevel(s) - if err != nil { - panic("invalid log level") - } - - return l -} diff --git a/pkg/util/console/term.go b/pkg/util/console/term.go index e044181c65..39972a8d21 100644 --- a/pkg/util/console/term.go +++ b/pkg/util/console/term.go @@ -10,18 +10,3 @@ import ( func IsTerminal() bool { return term.IsTerminal(os.Stdin.Fd()) } - -// GetWidth returns the width of the terminal (from stderr -- stdout might be piped) -// -// Returns 0 if we're not in a terminal -func GetWidth() (uint16, error) { - fd := os.Stderr.Fd() - if term.IsTerminal(fd) { - ws, err := term.GetWinsize(fd) - if err != nil { - return 0, err - } - return ws.Width, nil - } - return 0, nil -} diff --git a/pkg/util/errors.go b/pkg/util/errors.go deleted file mode 100644 index 22b50629a0..0000000000 --- a/pkg/util/errors.go +++ /dev/null @@ -1,12 +0,0 @@ -package util - -import "fmt" - -// WrapError is just a shortcut for using fmt.Errorf -// to wrap an error with a message -func WrapError(err error, message string) error { - if err == nil { - return nil - } - return fmt.Errorf("%s: %w", message, err) -} diff --git a/pkg/util/files/files.go b/pkg/util/files/files.go index 3a6b47e805..fa6ab22344 100644 --- a/pkg/util/files/files.go +++ b/pkg/util/files/files.go @@ -3,14 +3,12 @@ package files import ( "errors" "fmt" - "io" "os" "path" "strings" "github.com/mitchellh/go-homedir" "github.com/vincent-petithory/dataurl" - "golang.org/x/sys/unix" r8_path "github.com/replicate/cog/pkg/path" "github.com/replicate/cog/pkg/util/mime" @@ -41,59 +39,6 @@ func IsEmpty(path string) (bool, error) { return len(entries) == 0, nil } -func IsDir(path string) (bool, error) { - file, err := os.Stat(path) - if err != nil { - return false, err - } - return file.Mode().IsDir(), nil -} - -func IsExecutable(path string) bool { - return unix.Access(path, unix.X_OK) == nil -} - -func CopyFile(src string, dest string) error { - in, err := os.Open(src) - if err != nil { - return fmt.Errorf("Failed to open %s while copying to %s: %w", src, dest, err) - } - defer in.Close() - - out, err := os.Create(dest) - if err != nil { - return fmt.Errorf("Failed to create %s while copying %s: %w", dest, src, err) - } - defer out.Close() - - _, err = io.Copy(out, in) - if err != nil { - return fmt.Errorf("Failed to copy %s to %s: %w", src, dest, err) - } - return out.Close() -} - -func WriteIfDifferent(file, content string) error { - if _, err := os.Stat(file); err == nil { - bs, err := os.ReadFile(file) - if err != nil { - return err - } - if string(bs) == content { - return nil - } - } else if !errors.Is(err, os.ErrNotExist) { - return err - } - - // Write out a new requirements file - err := os.WriteFile(file, []byte(content), 0o644) - if err != nil { - return err - } - return nil -} - func WriteDataURLToFile(url string, destination string) (string, error) { if strings.HasPrefix(url, "data:None;base64") { url = strings.Replace(url, "data:None;base64", "data:;base64", 1) diff --git a/pkg/util/files/files_test.go b/pkg/util/files/files_test.go index 4812a5ea86..8d961c47f0 100644 --- a/pkg/util/files/files_test.go +++ b/pkg/util/files/files_test.go @@ -1,24 +1,12 @@ package files import ( - "os" "path/filepath" "testing" "github.com/stretchr/testify/require" ) -func TestIsExecutable(t *testing.T) { - dir := t.TempDir() - path := filepath.Join(dir, "test-file") - err := os.WriteFile(path, []byte{}, 0o644) - require.NoError(t, err) - - require.False(t, IsExecutable(path)) - require.NoError(t, os.Chmod(path, 0o744)) - require.True(t, IsExecutable(path)) -} - func TestWriteBadlyFormattedBase64DataURI(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "test-file") diff --git a/pkg/util/hash.go b/pkg/util/hash.go deleted file mode 100644 index f3fc7f45e9..0000000000 --- a/pkg/util/hash.go +++ /dev/null @@ -1,76 +0,0 @@ -package util - -import ( - "bytes" - "crypto/sha256" - "encoding/hex" - "errors" - "fmt" - "io" - "os" -) - -var ( - ErrInvalidRange = errors.New("Invalid byte range provided for file") -) - -func SHA256HashFile(path string) (string, error) { - hash := sha256.New() - - file, err := os.Open(path) - if err != nil { - return "", err - } - defer file.Close() - - if _, err := io.Copy(hash, file); err != nil { - return "", err - } - - return hex.EncodeToString(hash.Sum(nil)), nil -} - -func SHA256HashFileWithSaltAndRange(path string, start int, end int, salt string) (string, error) { - hash := sha256.New() - length := end - start - - if length < 0 { - return "", ErrInvalidRange - } - - file, err := os.Open(path) - if err != nil { - return "", err - } - defer file.Close() - - fileInfo, err := file.Stat() - if err != nil { - return "", err - } - - if fileInfo.Size() < int64(end) { - return "", ErrInvalidRange - } - - _, err = file.Seek(int64(start), 0) - if err != nil { - return "", fmt.Errorf("failed to open file pointer %s: %w", path, err) - } - buf := make([]byte, length) - n, err := file.Read(buf) - if err != nil { - return "", err - } - - buf = buf[:n] - var hashInput []byte - hashInput = append(hashInput, buf...) - hashInput = append(hashInput, []byte(salt)...) - - if _, err := io.Copy(hash, bytes.NewReader(hashInput)); err != nil { - return "", err - } - - return hex.EncodeToString(hash.Sum(nil)), nil -} diff --git a/pkg/util/hash_test.go b/pkg/util/hash_test.go deleted file mode 100644 index a508b0a386..0000000000 --- a/pkg/util/hash_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package util - -import ( - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestHash(t *testing.T) { - dir := t.TempDir() - path := filepath.Join(dir, "test.tmp") - d1 := []byte("hello\ngo\n") - err := os.WriteFile(path, d1, 0o644) - require.NoError(t, err) - - sha256, err := SHA256HashFile(path) - require.NoError(t, err) - require.Equal(t, "43d250d92b5dbb47f75208de8e9a9a321d23e85eed0dc3d5dfa83bc3cc5aa68c", sha256) -} - -func TestHashFileWithSaltAndRange(t *testing.T) { - dir := t.TempDir() - path := filepath.Join(dir, "test.tmp") - d1 := []byte("hello\nreplicate\nhello\n") - err := os.WriteFile(path, d1, 0o644) - require.NoError(t, err) - - _, err = SHA256HashFileWithSaltAndRange(path, 0, 60, "go\n") - require.Error(t, err) - - _, err = SHA256HashFileWithSaltAndRange(path, 23, 1, "go\n") - require.Error(t, err) - - sha256, err := SHA256HashFileWithSaltAndRange(path, 0, 6, "go\n") - require.NoError(t, err) - require.Equal(t, "43d250d92b5dbb47f75208de8e9a9a321d23e85eed0dc3d5dfa83bc3cc5aa68c", sha256) - - sha256, err = SHA256HashFileWithSaltAndRange(path, 16, 22, "go\n") - require.NoError(t, err) - require.Equal(t, "43d250d92b5dbb47f75208de8e9a9a321d23e85eed0dc3d5dfa83bc3cc5aa68c", sha256) -} diff --git a/pkg/util/ringbuffer.go b/pkg/util/ringbuffer.go deleted file mode 100644 index 311b15369f..0000000000 --- a/pkg/util/ringbuffer.go +++ /dev/null @@ -1,60 +0,0 @@ -package util - -import ( - "io" - "sync" -) - -// RingBufferWriter is a writer that writes to an underlying writer and also maintains -// a ring buffer of the last N bytes written. -type RingBufferWriter struct { - writer io.Writer - buffer []byte - size int - pos int - mu sync.Mutex -} - -// NewRingBufferWriter creates a new RingBufferWriter that writes to w and maintains -// a buffer of the last size bytes. -func NewRingBufferWriter(w io.Writer, size int) *RingBufferWriter { - return &RingBufferWriter{ - writer: w, - buffer: make([]byte, size), - size: size, - } -} - -// Write implements io.Writer interface -func (w *RingBufferWriter) Write(p []byte) (n int, err error) { - w.mu.Lock() - defer w.mu.Unlock() - - // Write to underlying writer - n, err = w.writer.Write(p) - if err != nil { - return n, err - } - - // Update ring buffer - for _, b := range p { - w.buffer[w.pos] = b - w.pos = (w.pos + 1) % w.size - } - - return n, nil -} - -// String returns the contents of the ring buffer as a string -func (w *RingBufferWriter) String() string { - w.mu.Lock() - defer w.mu.Unlock() - - // If buffer is not full, return what we have - if w.pos < w.size { - return string(w.buffer[:w.pos]) - } - - // Otherwise, return the last size bytes - return string(w.buffer[w.pos:]) + string(w.buffer[:w.pos]) -} diff --git a/pkg/util/shell/net.go b/pkg/util/shell/net.go deleted file mode 100644 index 1e60e59ecf..0000000000 --- a/pkg/util/shell/net.go +++ /dev/null @@ -1,57 +0,0 @@ -package shell - -import ( - "fmt" - "net" - "net/http" - "strconv" - "time" - - "github.com/replicate/cog/pkg/util/console" -) - -func WaitForPort(port int, timeout time.Duration) error { - start := time.Now() - for { - if PortIsOpen(port) { - return nil - } - - now := time.Now() - if now.Sub(start) > timeout { - return fmt.Errorf("Timed out") - } - - time.Sleep(100 * time.Millisecond) - } -} - -func WaitForHTTPOK(url string, timeout time.Duration) error { - start := time.Now() - console.Debugf("Waiting for %s to become accessible", url) - for { - now := time.Now() - if now.Sub(start) > timeout { - return fmt.Errorf("Timed out") - } - - time.Sleep(100 * time.Millisecond) - resp, err := http.Get(url) //#nosec G107 - if err != nil { - continue - } - if resp.StatusCode != http.StatusOK { - continue - } - console.Debugf("Got successful response from %s", url) - return nil - } -} - -func PortIsOpen(port int) bool { - conn, err := net.DialTimeout("tcp", net.JoinHostPort("", strconv.Itoa(port)), 100*time.Millisecond) - if conn != nil { - _ = conn.Close() - } - return err == nil -} diff --git a/pkg/util/shell/pipes.go b/pkg/util/shell/pipes.go deleted file mode 100644 index a42bf03886..0000000000 --- a/pkg/util/shell/pipes.go +++ /dev/null @@ -1,28 +0,0 @@ -package shell - -import ( - "bufio" - "io" -) - -type PipeFunc func() (io.ReadCloser, error) -type LogFunc func(args ...any) - -func PipeTo(pf PipeFunc, lf LogFunc) (done chan struct{}, err error) { - done = make(chan struct{}) - - pipe, err := pf() - if err != nil { - return nil, err - } - scanner := bufio.NewScanner(pipe) - go func() { - for scanner.Scan() { - line := scanner.Text() - lf(line) - } - done <- struct{}{} - }() - - return done, nil -} From 15d5ebfe348d16ec95cd043e2246aa2fc2d5af17 Mon Sep 17 00:00:00 2001 From: Michael Dwan Date: Thu, 30 Apr 2026 09:50:13 -0600 Subject: [PATCH 3/3] fmt: fix trailing newline in docker/errors.go --- pkg/docker/errors.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/docker/errors.go b/pkg/docker/errors.go index 42f40e4919..31751e69d0 100644 --- a/pkg/docker/errors.go +++ b/pkg/docker/errors.go @@ -46,4 +46,3 @@ func isMissingDeviceDriverError(err error) bool { return strings.Contains(msg, "could not select device driver") || strings.Contains(msg, "nvidia-container-cli: initialization error") } -