In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
from src.modules.content_encoder import ContentEncoderV2

class VGGFeatureExtractor(nn.Module):
    """VGG-16 convolutional trunk that turns an image batch into a feature sequence.

    The spatial grid of the final feature map is flattened row-major into a
    sequence, e.g. an input of shape (1, 3, 64, 256) yields (1, 16, 512).

    Args:
        weights_path: path to a torchvision VGG-16 state dict. The default is a
            relative path, so it is resolved against the current working directory.
    """

    def __init__(self, weights_path='weights/vgg16-397923af.pth'):
        super(VGGFeatureExtractor, self).__init__()
        vgg = models.vgg16()
        # map_location='cpu' lets the checkpoint load on CPU-only machines even if
        # it was saved from GPU tensors; the module can be moved afterwards.
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        state_dict = torch.load(weights_path, map_location='cpu')
        vgg.load_state_dict(state_dict)
        self.features = vgg.features

    def forward(self, x):
        """Return features of shape batch x (H' * W') x channels.

        x: image batch — assumed (batch, 3, H, W); TODO confirm channel count with callers.
        """
        features = self.features(x)              # batch x channels x H' x W'
        features = features.permute(0, 2, 3, 1)  # batch x H' x W' x channels
        # reshape (not view): it never fails on a non-contiguous tensor.
        return features.reshape(features.size(0), -1, features.size(3))

class GRUEncoder(nn.Module):
    """Single-layer bidirectional GRU encoder over a time-major feature sequence.

    Args:
        emb_dim: size of each input feature vector.
        enc_hid_dim: hidden size per direction of the GRU.
        dec_hid_dim: size of the summary state handed to the decoder.
        dropout: dropout probability applied to the input features.
    """

    def __init__(self, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True)
        self.fc = nn.Linear(2 * enc_hid_dim, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Encode ``src``.

        src: src_len x batch_size x emb_dim
        returns (outputs, hidden):
            outputs: src_len x batch_size x (2 * enc_hid_dim)
            hidden: batch_size x dec_hid_dim
        """
        outputs, h_n = self.rnn(self.dropout(src))
        # h_n stacks the final state of each direction (forward, then backward).
        last_forward, last_backward = h_n[-2], h_n[-1]
        summary = self.fc(torch.cat([last_forward, last_backward], dim=1))
        return outputs, torch.tanh(summary)

class Attention(nn.Module):
    """Additive attention of a decoder state over time-major encoder outputs.

    Args:
        enc_hid_dim: per-direction encoder hidden size (encoder features are 2x this).
        dec_hid_dim: decoder hidden size.
    """

    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.attn = nn.Linear(enc_hid_dim * 2 + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Return attention weights of shape batch_size x src_len (rows sum to 1).

        hidden: batch_size x dec_hid_dim
        encoder_outputs: src_len x batch_size x (2 * enc_hid_dim)
        """
        src_len = encoder_outputs.shape[0]
        keys = encoder_outputs.permute(1, 0, 2)                  # batch x src_len x enc
        query = hidden.unsqueeze(1).expand(-1, src_len, -1)      # batch x src_len x dec
        energy = torch.tanh(self.attn(torch.cat([query, keys], dim=2)))
        scores = self.v(energy).squeeze(2)
        return torch.softmax(scores, dim=1)
    
class GRUDecoder(nn.Module):
    """One-step attentive GRU decoder.

    Args:
        output_dim: vocabulary size (number of token classes).
        emb_dim: token embedding size.
        enc_hid_dim: per-direction encoder hidden size (encoder features are 2x this).
        dec_hid_dim: decoder GRU hidden size.
        dropout: dropout probability applied to token embeddings.
    """

    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()

        self.output_dim = output_dim
        self.attention = Attention(enc_hid_dim, dec_hid_dim)

        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        """Decode a single step.

        input: batch_size token ids
        hidden: batch_size x dec_hid_dim
        encoder_outputs: src_len x batch_size x (2 * enc_hid_dim)
        returns (prediction, hidden, attention):
            prediction: batch_size x output_dim logits
            hidden: batch_size x dec_hid_dim
            attention: batch_size x src_len weights
        """
        embedded = self.dropout(self.embedding(input.unsqueeze(0)))  # 1 x batch x emb
        a = self.attention(hidden, encoder_outputs).unsqueeze(1)      # batch x 1 x src_len
        # Attention-weighted sum of encoder features, back to time-major 1 x batch x enc.
        weighted = torch.bmm(a, encoder_outputs.permute(1, 0, 2)).permute(1, 0, 2)
        rnn_input = torch.cat((embedded, weighted), dim=2)
        # With one layer and one time step, `output` and `hidden` carry the same values.
        # No element-wise equality check here: it would force a device->host sync every step.
        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
        features = torch.cat((output.squeeze(0), weighted.squeeze(0), embedded.squeeze(0)), dim=1)
        prediction = self.fc_out(features)
        return prediction, hidden.squeeze(0), a.squeeze(1)

class Seq2Seq(nn.Module):
    """Bidirectional-GRU encoder plus attentive GRU decoder over feature sequences.

    Args:
        vocab_size: number of output token classes.
        encoder_hidden: per-direction hidden size of the encoder GRU.
        decoder_hidden: hidden size of the decoder GRU.
        img_channel: size of each input feature vector (per time step).
        decoder_embedded: size of the decoder token embedding.
        dropout: dropout probability for encoder and decoder.
    """

    def __init__(self, vocab_size, encoder_hidden, decoder_hidden, img_channel, decoder_embedded, dropout=0.1):
        super().__init__()

        self.encoder = GRUEncoder(img_channel, encoder_hidden, decoder_hidden, dropout)
        self.decoder = GRUDecoder(vocab_size, decoder_embedded, encoder_hidden, decoder_hidden, dropout)

    def forward_encoder(self, src):
        """Encode ``src`` once, for later step-by-step decoding.

        src: src_len x batch_size x img_channel
        returns (hidden, encoder_outputs):
            hidden: batch_size x decoder_hidden
            encoder_outputs: src_len x batch_size x (2 * encoder_hidden)
        """
        encoder_outputs, hidden = self.encoder(src)
        return (hidden, encoder_outputs)

    def forward_decoder(self, tgt, memory):
        """Run one decoding step on the last token of ``tgt``.

        tgt: timestep x batch_size token ids (only tgt[-1] is used)
        memory: (hidden, encoder_outputs) as returned by forward_encoder
        returns (output, memory):
            output: batch_size x 1 x vocab_size logits
            memory: (updated hidden, encoder_outputs)
        """
        hidden, encoder_outputs = memory
        output, hidden, _ = self.decoder(tgt[-1], hidden, encoder_outputs)
        return output.unsqueeze(1), (hidden, encoder_outputs)

    def forward(self, src, trg):
        """Decode with teacher forcing: step t is fed ``trg[t]``.

        src: src_len x batch_size x img_channel
        trg: trg_len x batch_size token ids
        returns: batch_size x trg_len x vocab_size logits, on the decoder's
            device and in the decoder's dtype.
        """
        batch_size = src.shape[1]
        encoder_outputs, hidden = self.encoder(src)

        step_logits = []
        for token in trg:  # iterating a tensor walks its first (time) dimension
            logits, hidden, _ = self.decoder(token, hidden, encoder_outputs)
            step_logits.append(logits)

        if not step_logits:
            # Nothing to decode: return an empty batch x 0 x vocab tensor.
            return src.new_zeros(batch_size, 0, self.decoder.output_dim)
        # Stacking keeps the decoder's dtype/device; a preallocated float32 CPU
        # buffer would silently cast half/double outputs and cost an extra copy.
        return torch.stack(step_logits, dim=1)

class OCRModel(nn.Module):
    """Image-to-sequence model: a feature extractor followed by a sequence model.

    Args:
        feature_extractor: module mapping images to batch x seq_len x channels features.
        seqmodel: module called as ``seqmodel(features, target)`` with time-major features.
    """

    def __init__(self, feature_extractor, seqmodel):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.seqmodel = seqmodel

    def forward(self, x, target):
        # Extractor output is batch-first; the sequence model consumes time-major input.
        time_major = self.feature_extractor(x).permute(1, 0, 2)
        return self.seqmodel(time_major, target)


# Shape smoke tests for the sequence components (seeded so re-runs are reproducible).
# VGGFeatureExtractor needs the pretrained weight file on disk; for a (1, 3, 64, 256)
# image it produces features of shape (1, 16, 512).
torch.manual_seed(0)

encoder = GRUEncoder(512, 512, 512, 0.1)
enc_out, enc_hidden = encoder(torch.rand(16, 1, 512))
assert enc_out.shape == (16, 1, 1024) and enc_hidden.shape == (1, 512)

decoder = GRUDecoder(48, 256, 256, 256, 0.1)
dec_out, dec_hidden, _ = decoder(torch.randint(0, 48, size=(1,)), torch.rand(1, 256), torch.rand(16, 1, 512))
assert dec_out.shape == (1, 48) and dec_hidden.shape == (1, 256)

seqmodel = Seq2Seq(48, 512, 512, 512, 512)
inp = torch.rand((16, 1, 512))
tgt = torch.randint(0, 48, size=(16, 1))
out = seqmodel(inp, tgt)
assert out.shape == (1, 16, 48)
out.shape

  from .autonotebook import tqdm as notebook_tqdm


torch.Size([1, 16, 48])

In [2]:
# Hyperparameters
input_size = 512  # VGG feature size
encoder_hidden = 512
decoder_hidden = 512
img_channel = 512  # must match the channel count of the VGG-16 feature map (512)
decoder_emb = 512
# NOTE(review): 1024 rather than the 52-character alphabet size; presumably chosen to
# match ContentEncoderV2(in_channels=1024) used later — confirm.
vocab_size = 1024
num_layers = 1  # NOTE(review): not passed to any constructor in this cell

# Instantiate the model
vgg_extractor = VGGFeatureExtractor()
seqmodel = Seq2Seq(vocab_size, encoder_hidden, decoder_hidden, img_channel, decoder_emb)
model = OCRModel(vgg_extractor, seqmodel)

# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# End-to-end shape check on a dummy image; labels are time x batch token ids drawn
# from the configured vocabulary. no_grad: no autograd graph is needed for a shape check.
images = torch.rand(1, 3, 64, 256)
labels = torch.randint(0, vocab_size, size=(16, 1))
with torch.no_grad():
    output = model(images, labels)
output.shape

torch.Size([1, 16, 1024])

In [5]:
# Size of the character alphabet used for the content encoder.
ALPHABET = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
vcs = len(ALPHABET)
vcs

52

In [6]:
# Smoke test: run a random batch of 16 token ids through ContentEncoderV2.
# NOTE(review): ids are drawn from [0, vcs + 2); the +2 presumably covers special
# tokens the encoder adds to its vocabulary — confirm against ContentEncoderV2.
content_encoder = ContentEncoderV2(
    'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz',
    in_channels=1024,
    n_heads=8,
    d_head=128,
)
inp = torch.randint(0, vcs + 2, size=(1, 16))
out = content_encoder(inp)
out.shape

torch.Size([1, 16, 1024])

In [8]:
from src.criterion import ContentPerceptualLoss, SupConLoss

In [None]:
def train_model(model, dataloader, criterion, optimizer, device, num_epochs=10):
    """Train a ``model(images, target)`` sequence model with teacher forcing.

    Args:
        model: module called as ``model(images, target)`` that returns logits of
            shape batch x time x vocab (as OCRModel/Seq2Seq do in this notebook).
        dataloader: iterable of dicts with ``'image'`` and ``'label'`` tensors.
            NOTE(review): labels are assumed to be time x batch token ids that begin
            with a start-of-sequence token — confirm against the dataset.
        criterion: loss taking (N, vocab) logits and (N,) class targets,
            e.g. nn.CrossEntropyLoss.
        optimizer: optimizer over ``model.parameters()``.
        device: device the model and every batch are moved to.
        num_epochs: number of passes over ``dataloader``.

    Returns:
        List with the mean training loss of each epoch (nan for an empty loader).
    """
    model.to(device)
    epoch_losses = []
    for epoch in range(num_epochs):
        model.train()
        running_loss, n_batches = 0.0, 0
        for batch in dataloader:
            images = batch['image'].to(device)
            labels = batch['label'].to(device)
            # Seq2Seq.forward feeds token t as the decoder input at step t, so the
            # step-t logits are scored against token t+1.
            decoder_input, decoder_target = labels[:-1], labels[1:]
            logits = model(images, decoder_input)  # batch x time x vocab
            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                # targets are time-major; transpose to line up with batch-first logits
                decoder_target.transpose(0, 1).reshape(-1),
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1

        mean_loss = running_loss / n_batches if n_batches else float('nan')
        epoch_losses.append(mean_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss}')
    return epoch_losses