In [2]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import os
import pandas as pd
from transformers import AutoTokenizer, AutoModel, AutoProcessor

In [2]:
# Compute device used by the notebook. The string form is kept next to the
# torch.device object built from it so both are available to later cells.
device_type = "cpu"
device = torch.device(device_type)

In [4]:
# Input locations, relative to the notebook's working directory:
# the captions file (read as CSV below) and the folder holding the images.
captions_path = "Clean-1Sentences_withComma.txt"
images_path = "HindiImages"

In [5]:
# Load the full captions table and preview up to the first 20 rows.
df = pd.read_csv(captions_path)
df.head(20)

Unnamed: 0,image,caption
0,1000268201_693b08cb0e,गुलाबी पोशाक में बच्चा प्रवेश के रास्ते में सी...
1,1001773457_577c3a7d70,काला कुत्ता और चित्तीदार कुत्ता लड़ रहे हैं
2,1002674143_1b742ab4b8,पेंट में ढकी छोटी लड़की कटोरे में अपने हाथों स...
3,1003163366_44323f5815,आदमी बेंच पर लेट जाता है जबकि उसका कुत्ता उसके...
4,1007129816_e794419615,नारंगी टोपी में आदमी कुछ पर अभिनीत।
5,1007320043_627395c3d8,रस्सी के जाल पर खेलता बच्चा।
6,1009434119_febe49276a,सफेद बाड़ से घिरे घास के बगीचे में काला और सफे...
7,1012212859_01547e3f17,कुत्ता किनारे के पास अपना सिर हिलाता है उसके ब...
8,1015118661_980735411b,शहर में स्टोनी की दीवार के सामने लड़का मुस्कुर...
9,1015584366_dfcec3c85a,लॉग पर काला कुत्ता छलांग लगाता है।


In [5]:
class HindiFlickerDataset(torch.utils.data.Dataset):
    """Image/caption pairs backed by a captions CSV and a folder of JPEG files.

    The CSV is accessed positionally: column 0 holds the image id (the file
    name without the ``.jpg`` extension) and column 1 holds the caption.

    Args:
        root_dir: Directory containing the ``<image id>.jpg`` files.
        captions_file: Path to the captions CSV.
        transform: Optional callable applied to the loaded PIL image.
        nrows: Maximum number of CSV rows to read (forwarded to
            ``pandas.read_csv``). Defaults to 500, the value that used to be
            hard-coded; pass ``None`` to read the whole file.
    """

    def __init__(self, root_dir, captions_file, transform=None, nrows=500):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file, nrows=nrows)
        self.transform = transform

    def __len__(self):
        """Return the number of caption rows that were loaded."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for row ``idx``.

        ``idx`` may be a Python int or a tensor holding a single index.
        The image is always returned in RGB mode (before ``transform``).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # str() guards against ids that pandas parsed as numbers.
        img_path = os.path.join(self.root_dir, str(self.df.iloc[idx, 0]) + ".jpg")

        # Force three channels so grayscale/RGBA/CMYK files do not break a
        # 3-channel Normalize or batch collation. convert() returns a fully
        # loaded copy, so the file handle can be closed right away.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")

        if self.transform:
            image = self.transform(image)

        caption = self.df.iloc[idx, 1]

        return image, caption

In [7]:
class CosineSimilarityLoss(torch.nn.Module):
    """Mean cosine distance between row-aligned feature batches.

    For each row pair the distance is ``1 - cos(image, text)``, which lies in
    ``[0, 2]``: 0 for vectors pointing the same way, 2 for opposite vectors.
    """

    def __init__(self):
        super().__init__()
        # Similarity is taken along dim=1, i.e. one value per row.
        self.cosine_similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

    def forward(self, image_features, text_features):
        """Return the scalar mean of the per-row cosine distances."""
        per_row_similarity = self.cosine_similarity(image_features, text_features)
        per_row_distance = 1 - per_row_similarity
        return torch.mean(per_row_distance)

In [8]:
def train(model, dataloader, optimizer, loss_function, device,
          num_epochs=5, image_processor=None, text_tokenizer=None):
    """Fine-tune ``model`` so paired image and caption features align.

    Args:
        model: Module exposing ``get_image_features(**inputs)`` and
            ``get_text_features(**inputs)``.
        dataloader: Iterable of ``(image_batch, captions)`` pairs that also
            supports ``len()``; ``image_batch`` is a tensor and ``captions``
            a sequence of strings.
        optimizer: Optimizer over ``model``'s parameters.
        loss_function: Callable ``(image_features, text_features) -> scalar``.
        device: Device (string or ``torch.device``) the inputs are moved to.
        num_epochs: Number of passes over ``dataloader`` (default 5).
        image_processor: Callable ``(images=..., return_tensors="pt")`` whose
            result supports ``.to(device)`` and ``**`` unpacking. Defaults to
            the notebook-level ``processor``.
        text_tokenizer: Callable tokenizer with the same result contract.
            Defaults to the notebook-level ``tokenizer``.

    Returns:
        list[float]: Mean loss of each epoch, in order.
    """
    # Fall back to the notebook globals so existing five-argument calls work.
    if image_processor is None:
        image_processor = processor
    if text_tokenizer is None:
        text_tokenizer = tokenizer

    # Decide from the device actually passed in, not from an unrelated global.
    on_cuda = torch.device(device).type == "cuda"
    epoch_losses = []

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for idx, (inputs, captions) in enumerate(dataloader):
            print("Processing batch {}".format(idx))
            optimizer.zero_grad()

            # Min-max rescale the whole batch to [0, 1] before handing it to
            # the processor; the clamp avoids 0/0 when the batch is constant.
            low, high = inputs.min(), inputs.max()
            inputs = (inputs - low) / (high - low).clamp_min(1e-12)
            inputs = image_processor(images=inputs, return_tensors="pt").to(device)
            image_features = model.get_image_features(**inputs)

            # Token tensors must live on the same device as the model.
            caption = text_tokenizer(
                list(captions),
                padding="max_length",
                truncation=True,
                max_length=64,
                return_tensors="pt",
            ).to(device)
            text_features = model.get_text_features(**caption)

            loss = loss_function(image_features, text_features)
            loss.backward()
            optimizer.step()

            # .item() detaches the value; summing the tensor itself would keep
            # every batch's autograd history alive for the whole epoch.
            total_loss += loss.item()

            del image_features, text_features, inputs, caption, loss
            if on_cuda:
                torch.cuda.empty_cache()

        # max(..., 1) keeps an empty dataloader from raising ZeroDivisionError.
        mean_loss = total_loss / max(len(dataloader), 1)
        epoch_losses.append(mean_loss)
        print("Mean Loss after {} epochs: {}".format(epoch+1, mean_loss))

    return epoch_losses

In [9]:
# Per-image preprocessing applied by the dataset before batching.
transform = transforms.Compose([
    # Fixed square size so images can be stacked into a batch.
    transforms.Resize((256, 256)),
    # PIL image -> float tensor.
    transforms.ToTensor(),
    # NOTE(review): these look like the usual ImageNet channel statistics —
    # confirm they are what the downstream processor/model expects, since the
    # processor may apply its own normalization on top of this.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [10]:
# Build the dataset from the image folder and captions file, then batch it.
dataset = HindiFlickerDataset(root_dir=images_path, captions_file=captions_path, transform=transform)
# NOTE(review): shuffle=False yields the same batch order on every pass —
# confirm that is intended for training.
dataloader = DataLoader(dataset, batch_size=8, shuffle=False)
# Number of batches per pass over the data.
print(len(dataloader))

63


In [11]:
# The model, image processor and tokenizer all come from the same checkpoint.
checkpoint_id = "google/siglip-so400m-patch14-384"

# nn.Module.to() returns the module itself, so loading and placement chain.
model = AutoModel.from_pretrained(checkpoint_id, low_cpu_mem_usage=True).to(device)

# do_rescale=False: the processor is told not to rescale pixel values itself.
processor = AutoProcessor.from_pretrained(
    checkpoint_id,
    low_cpu_mem_usage=True,
    do_rescale=False,
)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_id, low_cpu_mem_usage=True)

loss_function = CosineSimilarityLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

In [None]:
# Run the training loop on the objects built in the cells above.
train(model, dataloader, optimizer, loss_function, device)

Processing batch 0
Processing batch 1
Processing batch 2
Processing batch 3
Processing batch 4
Processing batch 5
Processing batch 6
Processing batch 7
Processing batch 8
Processing batch 9
Processing batch 10
Processing batch 11
Processing batch 12
Processing batch 13
Processing batch 14
Processing batch 15
Processing batch 16
Processing batch 17
Processing batch 18
Processing batch 19
Processing batch 20
Processing batch 21
Processing batch 22
Processing batch 23
Processing batch 24
Processing batch 25
Processing batch 26
Processing batch 27
Processing batch 28
Processing batch 29
Processing batch 30
Processing batch 31
Processing batch 32
Processing batch 33
Processing batch 34
Processing batch 35
Processing batch 36
Processing batch 37
Processing batch 38
Processing batch 39
Processing batch 40
Processing batch 41
Processing batch 42
Processing batch 43
Processing batch 44
Processing batch 45
Processing batch 46
Processing batch 47
Processing batch 48
Processing batch 49
Processing

In [None]:
# Persist only the model's state_dict (not the full module object) next to
# the notebook.
torch.save(model.state_dict(), "./trained_model.pth")