In [1]:
import math
from pytorch_lightning.loggers import WandbLogger
from tqdm import tqdm
import numpy as np
import torch
import pytorch_lightning as pl
from filelock import FileLock
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
import os
import torchvision
import clip
from clip_lt.utils.labels_names import labels_names
import torch.nn as nn

# Shared Weights & Biases logger; train_mnist passes this same instance to every pl.Trainer it builds.
# NOTE(review): no project/run-name arguments are given here, so W&B's own defaults apply — confirm that is intended.
wandb_logger = WandbLogger()



In [2]:
# Dataset root in torchvision ImageFolder layout; the classifier reads its
# `train/` and `val/` sub-directories. Set the DATASET_DIR environment
# variable to point elsewhere instead of editing this cell.
# Previously used ImageNet-LT root: '/Volumes/black_ssd/datasets/imagenet_lt/'
# NOTE(review): the text-feature bank is built from 1000 `labels_names`
# entries; ImageFolder assigns label indices from sorted sub-directory names,
# and those indices must line up with rows of that bank. TinyImageNet normally
# has 200 classes — confirm the mapping before trusting accuracy numbers.
# os.path.join(..., '') guarantees a trailing separator so plain string
# concatenation like `dataset_dir_path + 'train/'` stays valid.
dataset_dir_path = os.path.join(
    os.environ.get('DATASET_DIR', '/Users/rotemisraeli/Documents/datasets/TinyImageNet/'),
    '',
)

class LightningMNISTClassifier(pl.LightningModule):
    """Fine-tune CLIP ViT-B/32 as a classifier over a fixed bank of class-text embeddings.

    Logits are the scaled cosine similarities between each image embedding and
    every row of a pre-computed text-feature tensor loaded from disk. Despite
    the (backward-compatible) class name, nothing here is MNIST-specific: data
    is read with ``torchvision.datasets.ImageFolder``.

    Args:
        config: dict with required keys ``"lr"`` and ``"batch_size"``, and
            optional keys ``"text_features_path"`` (default
            ``"../text_features.pt"``) and ``"num_workers"`` (default ``1``).
        data_dir: dataset root containing ``train/`` and ``val/``
            sub-directories in ImageFolder layout.
    """

    def __init__(self, config, data_dir=None):
        super().__init__()
        self.data_dir = data_dir
        self.lr = config['lr']
        self.batch_size = config['batch_size']
        self.num_workers = config.get('num_workers', 1)
        self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)

        # map_location='cpu': the file may have been saved from a CUDA session.
        text_features = torch.load(
            config.get('text_features_path', '../text_features.pt'), map_location='cpu'
        )
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # A buffer (not a plain attribute) is moved to the training device
        # together with the module's parameters.
        self.register_buffer('text_features', text_features)

        # Log-temperature, initialised so that exp(logit_scale) == 1 / 0.07.
        # It is registered as a Parameter (so the optimizer sees it) and .exp()
        # is applied in forward. Calling .exp() once here produced a non-leaf
        # tensor whose autograd graph is freed by the first backward(), making
        # the second training step fail.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, x):
        """Return unnormalised class logits for the image batch ``x``.

        One column per row of ``self.text_features``; no softmax is applied
        because ``F.cross_entropy`` expects raw logits.
        """
        image_features = self.clip_model.encode_image(x)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # The saved text features may be fp16 while the image tower runs in
        # fp32 (or vice versa); match dtypes before the matmul.
        text_features = self.text_features.to(image_features.dtype)
        # Scaled cosine similarity.
        return self.logit_scale.exp() * image_features @ text_features.t()

    def cross_entropy_loss(self, logits, labels):
        """Mean cross-entropy between raw ``logits`` and integer class ``labels``."""
        return F.cross_entropy(logits, labels)

    def accuracy(self, logits, labels):
        """Fraction of rows whose arg-max logit equals the label, as a 0-d float tensor."""
        predicted = logits.detach().argmax(dim=1)
        return (predicted == labels).float().mean()

    def training_step(self, train_batch, batch_idx):
        """Compute and log loss/accuracy for one ``(images, labels)`` batch; return the loss."""
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)

        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", accuracy, prog_bar=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Return per-batch validation loss and accuracy for aggregation at epoch end."""
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": accuracy}

    def validation_epoch_end(self, outputs):
        """Log the mean of the per-batch validation losses and accuracies."""
        if not outputs:
            # torch.stack([]) would raise; nothing to log when no batches ran.
            return
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)

    def _image_folder(self, split):
        """Build an ImageFolder over ``<data_dir>/<split>`` using CLIP's preprocessing."""
        if self.data_dir is None:
            raise ValueError("data_dir must be set to build the dataloaders")
        # NOTE(review): ImageFolder derives label indices from sorted
        # sub-directory names; they must match the row order of
        # self.text_features — confirm for the dataset in use.
        return torchvision.datasets.ImageFolder(
            os.path.join(self.data_dir, split), transform=self.clip_preprocess
        )

    def train_dataloader(self):
        """Shuffled loader over the ``train`` split."""
        train_data = self._image_folder('train')
        return DataLoader(
            train_data, batch_size=int(self.batch_size), num_workers=self.num_workers, shuffle=True
        )

    def val_dataloader(self):
        """Un-shuffled loader over the ``val`` split."""
        val_data = self._image_folder('val')
        return DataLoader(val_data, batch_size=int(self.batch_size), num_workers=self.num_workers)

    def configure_optimizers(self):
        """Adam over all registered parameters (CLIP weights and logit_scale)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


def train_mnist(config):
    """Build the CLIP classifier from ``config`` and fit it.

    Trains for ``config.get("max_epochs", 1)`` epochs on ``dataset_dir_path``,
    logging to the module-level ``wandb_logger``.

    Returns:
        The fitted ``LightningMNISTClassifier``.
    """
    model = LightningMNISTClassifier(config, data_dir=dataset_dir_path)
    trainer = pl.Trainer(max_epochs=config.get("max_epochs", 1), logger=wandb_logger)

    trainer.fit(model)
    return model

In [3]:
def train_mnist_no_tune():
    """Run one fine-tuning job with a fixed, hand-picked config (no hyper-parameter search).

    Only keys the classifier actually reads are listed; the former
    ``layer_1_size`` / ``layer_2_size`` entries were never used.
    """
    config = {
        "lr": 1e-5,
        "batch_size": 32,
    }
    train_mnist(config)

In [3]:
# Launch the single fixed-config fine-tuning run defined above.
train_mnist_no_tune()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
[34m[1mwandb[0m: Currently logged in as: [33mrotem98[0m (use `wandb login --relogin` to force relogin)



  | Name       | Type | Params
------------------------------------
0 | clip_model | CLIP | 151 M 
------------------------------------
151 M     Trainable params
0         Non-trainable params
151 M     Total params
605.109   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

Process wandb_internal:
Traceback (most recent call last):
  File "/Users/rotemisraeli/opt/anaconda3/envs/env1/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/Users/rotemisraeli/opt/anaconda3/envs/env1/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/rotemisraeli/opt/anaconda3/envs/env1/lib/python3.8/site-packages/wandb/sdk/internal/internal.py", line 162, in wandb_internal
    thread.join()
  File "/Users/rotemisraeli/opt/anaconda3/envs/env1/lib/python3.8/threading.py", line 1011, in join
    self._wait_for_tstate_lock()
  File "/Users/rotemisraeli/opt/anaconda3/envs/env1/lib/python3.8/threading.py", line 1027, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
KeyboardInterrupt
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  F

Traceback (most recent call last):
  File "/Users/rotemisraeli/opt/anaconda3/envs/env1/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 156, in training_step
    return self.training_type_plugin.training_step(*args)
  File "/Users/rotemisraeli/opt/anaconda3/envs/env1/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 125, in training_step
    return self.lightning_module.training_step(*args, **kwargs)
  File "/var/folders/qq/bg7j_rts29s3hwdtklnnfg0c0000gn/T/ipykernel_32874/4043055602.py", line 41, in training_step
    logits = self.forward(x)
  File "/var/folders/qq/bg7j_rts29s3hwdtklnnfg0c0000gn/T/ipykernel_32874/4043055602.py", line 17, in forward
    image_features = self.clip_model.encode_image(x)
  File "/Users/rotemisraeli/opt/anaconda3/envs/env1/lib/python3.8/site-packages/clip/model.py", line 337, in encode_image
    return self.visual(image.type(self.dtype))
  File "/Users/rotemisraeli/opt/anaconda

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/Users/rotemisraeli/opt/anaconda3/envs/env1/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 156, in training_step
    return self.training_type_plugin.training_step(*args)
  File "/Users/rotemisraeli/opt/anaconda3/envs/env1/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 125, in training_step
    return self.lightning_module.training_step(*args, **kwargs)
  File "/var/folders/qq/bg7j_rts29s3hwdtklnnfg0c0000gn/T/ipykernel_32874/4043055602.py", line 41, in training_step
    logits = self.forward(x)
  File "/var/folders/qq/bg7j_rts29s3hwdtklnnfg0c0000gn/T/ipykernel_32874/4043055602.py", line 17, in forward
    image_features = self.clip_model.encode_image(x)
  File "/Users/rotemisraeli/opt/anaconda3/envs/env1/lib/python3.8/site-packages/clip/model.py", line 337, in encode_image
    return self.visual(image.type(self.dtype))
  File "/Users/rotemisraeli/opt/anaconda

TypeError: object of type 'NoneType' has no len()

Error in callback <function _WandbInit._pause_backend at 0x7fb32d221ca0> (for post_run_cell):


Exception: The wandb backend process has shutdown

In [None]:
from tqdm import tqdm
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)


def _prompt_for(label_name):
    """Return the CLIP prompt for one class name, choosing "a"/"an" by its first letter."""
    # .lower() so capitalised names (e.g. "African ...") still get "an";
    # [:1] so an empty name falls back to "a" instead of raising IndexError.
    article = "an" if label_name[:1].lower() in set("aeiou") else "a"
    return f'a photo of {article} {label_name}'


# One prompt per class index 0..999, using only the first comma-separated
# synonym of each entry in labels_names.
texts = [_prompt_for(labels_names[i].split(',')[0]) for i in range(1000)]

texts2 = clip.tokenize(texts).to(device)
with torch.no_grad():
    text_features = model.encode_text(texts2)

In [None]:
# Move to CPU before saving so the file also loads on machines without CUDA
# (torch.load otherwise tries to restore tensors onto the saving device).
# NOTE(review): LightningMNISTClassifier loads '../text_features.pt' (one
# directory up) — make sure the saved file ends up at that location.
torch.save(text_features.cpu(), 'text_features.pt')

In [None]:
# NOTE(review): leftover smoke-test output with no effect on results; candidate for deletion.
print('test')