In [1]:
import math
from pytorch_lightning.loggers import WandbLogger
from tqdm import tqdm
import numpy as np
import torch
import pytorch_lightning as pl
from filelock import FileLock
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
import os
import torchvision
import clip
from clip_lt.utils.labels_names import labels_names
import torch.nn as nn
from clip_lt.blip.models.blip import blip_feature_extractor



In [2]:
wandb_logger = WandbLogger()
pass

[34m[1mwandb[0m: Currently logged in as: [33mrotem98[0m (use `wandb login --relogin` to force relogin)


In [3]:
dataset_dir_path = '/Volumes/black_ssd/datasets/imagenet_lt/'
# dataset_dir_path = '/Users/rotemisraeli/Documents/datasets/imagenet_lt/'

model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth'
image_size = 224

class LightningMNISTClassifier(pl.LightningModule):
    def __init__(self, config, data_dir=None):
        super(LightningMNISTClassifier, self).__init__()
        self.data_dir = data_dir
        self.lr = config['lr']
        self.batch_size = config['batch_size']
        self.blip_model = blip_feature_extractor(pretrained=model_url, image_size=image_size, vit='base')
        self.blip_model.eval()
        self.blip_model.to(self.device)
        self.text_features = torch.load('../text_features.pt')
        self.text_features = self.text_features / self.text_features.norm(dim=-1, keepdim=True)
        self.logit_scale = (nn.Parameter(torch.ones([]) * np.log(1 / 0.07))).exp()

        self.fc = nn.Sequential(
            nn.Linear(768,1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024,1000)
        ).to(self.device)

    def forward(self, x):
        with torch.no_grad():
            image_features = self.blip_model(x, '', mode='image')[:,0]
        # print(image_features.shape,x.shape)
        out = self.fc(image_features)
        out = out#.softmax(dim=-1)

        return out

    def old_forward(self, x):
        image_features = self.clip_model.encode_image(x)
        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logits_per_image = self.logit_scale * image_features @ self.text_features.t()
        logits_per_text = logits_per_image.t()
        probs = logits_per_image#.softmax(dim=-1)
        # print(probs.shape,probs)
        return probs

    def cross_entropy_loss(self, logits, labels):
        return F.cross_entropy(logits, labels)

    def accuracy(self, logits, labels):
        _, predicted = torch.max(logits.data, 1)
        correct = (predicted == labels).sum().item()
        accuracy = correct / len(labels)
        return torch.tensor(accuracy)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)

        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", accuracy,prog_bar=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": accuracy}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)

    def train_dataloader(self):
        transform = transforms.Compose([
            transforms.Resize((image_size,image_size),interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])
        train_data = torchvision.datasets.ImageFolder(self.data_dir+'train/',transform=transform)
        return DataLoader(train_data, batch_size=int(self.batch_size),num_workers=4,shuffle=True)

    def val_dataloader(self):
        transform = transforms.Compose([
            transforms.Resize((image_size,image_size),interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])
        train_data = torchvision.datasets.ImageFolder(self.data_dir+'val/',transform=transform)
        return DataLoader(train_data, batch_size=int(self.batch_size),num_workers=4)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.fc.parameters(), lr=self.lr)
        return optimizer


def train_mnist(config):
    model = LightningMNISTClassifier(config,data_dir=dataset_dir_path)
    trainer = pl.Trainer(max_epochs=2, logger=wandb_logger)

    trainer.fit(model)

In [4]:
def train_mnist_no_tune():
    config = {
        "layer_1_size": 128,
        "layer_2_size": 256,
        "lr": 4e-3,
        "batch_size": 128
    }
    train_mnist(config)

In [5]:
train_mnist_no_tune() #115s/it 128 batch size

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name       | Type       | Params
------------------------------------------
0 | blip_model | BLIP_Base  | 223 M 
1 | fc         | Sequential | 1.8 M 
------------------------------------------
224 M     Trainable params
0         Non-trainable params
224 M     Total params
899.478   Total estimated model params size (MB)


load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

wandb: Network error (ReadTimeout), entering retry loop.


Validation: 0it [00:00, ?it/s]

wandb: Network error (ReadTimeout), entering retry loop.


Validation: 0it [00:00, ?it/s]

wandb: Network error (ReadTimeout), entering retry loop.


In [None]:
from tqdm import tqdm
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
texts = []
for i in range(1000):
    label_name = labels_names[i].split(',')[0]
    if label_name[0] in 'aouie':
        texts.append(f'a photo of an {label_name}')
    else:
        texts.append(f'a photo of a {label_name}')

# texts2 = clip.tokenize(texts).to(device)
# with torch.no_grad():
#     text_features = model.encode_text(texts2)

In [None]:
torch.save(text_features,'text_features2.pt')

In [None]:
print('test')

In [None]:
texts[3]

In [None]:
clip_model, clip_preprocess = clip.load("ViT-B/32")


In [None]:
clip_model.visual.features