In [1]:
import torchvision
import torch

In [2]:
efficientnet = torchvision.models.efficientnet_b0(weights=torchvision.models.EfficientNet_B0_Weights.IMAGENET1K_V1)
efficientnet.classifier[1] = torch.nn.Identity()

In [3]:
from pytorch_model_summary import summary

In [4]:
print(summary(efficientnet, torch.zeros(69, 3, 224, 224)))

-------------------------------------------------------------------------------
           Layer (type)           Output Shape         Param #     Tr. Param #
               Conv2d-1     [69, 32, 112, 112]             864             864
          BatchNorm2d-2     [69, 32, 112, 112]              64              64
                 SiLU-3     [69, 32, 112, 112]               0               0
               MBConv-4     [69, 16, 112, 112]           1,448           1,448
               MBConv-5       [69, 24, 56, 56]           6,004           6,004
               MBConv-6       [69, 24, 56, 56]          10,710          10,710
               MBConv-7       [69, 40, 28, 28]          15,350          15,350
               MBConv-8       [69, 40, 28, 28]          31,290          31,290
               MBConv-9       [69, 80, 14, 14]          37,130          37,130
              MBConv-10       [69, 80, 14, 14]         102,900         102,900
              MBConv-11       [69, 80, 14, 14]     

In [5]:
decoder_layer = torch.nn.TransformerDecoderLayer(d_model=512, nhead=8)
transformer_decoder = torch.nn.TransformerDecoder(decoder_layer, num_layers=6)

In [6]:
efficientnet(torch.zeros(69, 3, 224, 224))

tensor([[ 0.1626, -0.0023,  0.2371,  ...,  0.4563,  0.0158,  0.0000],
        [ 0.3846,  0.0000,  0.0000,  ...,  0.0548, -0.0942, -0.1344],
        [ 0.2714, -0.0000,  0.5801,  ...,  0.1179,  0.1266, -0.0000],
        ...,
        [-0.0000, -0.0657,  0.3875,  ..., -0.1731,  0.0000, -0.1454],
        [ 0.1431,  0.0428, -0.0000,  ..., -0.0696, -0.1450,  0.1198],
        [ 0.0000, -0.0023,  0.2371,  ...,  0.0000,  0.0158,  0.1852]],
       grad_fn=<AsStridedBackward0>)

In [10]:
recognizer_mini = ChessRecognizer(512, 4, 4)
chr2idx = {'>': 0, '<': 1, 'A': 2, 'B': 3}
recognizer_full_mini = ChessRecognizerFull(chr2idx, recognizer_mini)
#print(recognizer_full.embedding_size)
recognizer_full_mini(torch.zeros(2, 3, 224, 224), ['AB'] * 2, max_len=4)
#recognizer_full.convert_text_to_tensor(['AB>'] * 4, 90)
recognizer_full_mini.compute_loss(torch.zeros(2, 3, 224, 224), ['AB'] * 2, max_len=4)
recognizer_full_mini.generate_output(torch.zeros(3, 224, 224), max_len=4)

''

In [29]:
from train import trainer, evaluator, recognizer_full

In [30]:
trainer.run(train_dataloader, max_epochs=2)

Running validation
Row accuracy: 0.0, Token accuracy: 0.11219946571682991


State:
	iteration: 18
	epoch: 2
	epoch_length: 9
	max_epochs: 2
	output: 7.544007301330566
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>

In [51]:
print(summary(recognizer_full.to(device), dataset[0][0].to(device).float().unsqueeze(0)/255.0, ['bbbbbb']))

----------------------------------------------------------------------------
           Layer (type)        Output Shape         Param #     Tr. Param #
            Embedding-1        [1, 90, 256]           9,216           9,216
   PositionalEncoding-2        [90, 1, 256]               0               0
      ChessRecognizer-3        [90, 1, 256]      10,721,916       6,714,368
Total params: 10,731,132
Trainable params: 6,723,584
Non-trainable params: 4,007,548
----------------------------------------------------------------------------


In [77]:
params = list(recognizer_full.parameters())

In [81]:
sum(torch.numel(param) for param in params) - 10721916-9216

0

In [31]:
def generate_output(self, input_img, max_len: int=90) -> str:
    with torch.no_grad():
        current_seq = self.idx2chr[2] + 'rrrrr'
        curr_idx = 0
        pp = None
        while len(current_seq) < max_len:
            current_seq_tensor, curr_seq_mask = self.convert_text_to_tensor([current_seq], max_len)
            recognizer_output = self.recognizer(torch.unsqueeze(input_img, 0), current_seq_tensor, curr_seq_mask | False) # comes out as max_len x 1 x d_model
            recognizer_output = self.convert_output_tensor_to_logits(recognizer_output)
            pre_logits = recognizer_output[-1, :, :] # take the last element of the sequence, now it is 1 x d_model
            print("Actual")
            print(pre_logits)
            if pp is not None:
                print("Diff")
                print(pre_logits - pp)
            pp = pre_logits
            filtered_logits = top_k_top_p_filtering(pre_logits, top_k=0, top_p=0.85) # disable top_k, set top_p = 0.85
            filtered_logits = torch.nn.functional.softmax(pre_logits, dim=1) # still 1 x d_model
            next_token = torch.multinomial(filtered_logits, 1, replacement=True).item()
            if next_token == 1:
                pass
            #break
            current_seq += self.idx2chr[next_token]
        return current_seq[1:]

In [32]:
N = 0
recognizer_full.eval()
output, desired_output = generate_output(recognizer_full, dataset[N][0].to(device).float() / 255.0, max_len=20), dataset[N][1]
print(output)
print(desired_output)
recognizer_full.compute_loss(dataset[N][0].to(device).float().unsqueeze(0), ['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1'])

Actual
tensor([[   0.2632,   23.2920,    8.0402,   23.1371,   23.2585,   21.0327,
           14.0558,   21.0366,   23.8089,   23.1485,   22.6810,   12.9351,
           17.6448,   10.3701,   16.2547,   17.3641,   23.9682,   20.6138,
           10.7859,   13.4753,   12.8044,  -16.1548,   21.5267,   25.6231,
          -25.4533,    5.3132, -126.5298,  -29.7871,   -8.6416,    6.2138,
          -22.0798,   16.7549,   19.5347,   19.2026,   11.8437,   13.3179]],
       device='cuda:0')
Actual
tensor([[   0.2636,   23.2922,    8.0400,   23.1371,   23.2584,   21.0327,
           14.0556,   21.0364,   23.8088,   23.1486,   22.6810,   12.9352,
           17.6446,   10.3698,   16.2545,   17.3640,   23.9682,   20.6139,
           10.7859,   13.4755,   12.8043,  -16.1549,   21.5269,   25.6229,
          -25.4533,    5.3135, -126.5296,  -29.7870,   -8.6416,    6.2139,
          -22.0799,   16.7548,   19.5349,   19.2026,   11.8437,   13.3179]],
       device='cuda:0')
Diff
tensor([[ 4.0436e-04,  1.8311

tensor(8.3286, device='cuda:0', grad_fn=<MeanBackward0>)

In [110]:
def compute_loss(self, input_img, desired_output_texts: List[str], max_len: int=8):
        desired_output_texts = [txt + self.idx2chr[0] for txt in desired_output_texts]
        shifted_outputs = [(self.idx2chr[1] + output)[:-1] for output in desired_output_texts]
        # mask comes out as batch_size x max_len x max_len
        current_seq_tensor, mask = self.convert_text_to_tensor(shifted_outputs, max_len)
        # Index at 0 for the second coordinate, because it's just repeated along that dimension. now it is batch_size x max_len
        mask = 1 - mask[:, 0, :].float()
        idx_list = self.convert_text_to_idx_list(desired_output_texts, max_len) # comes out as batch_size x max_len
        ideal_output = torch.nn.functional.one_hot(idx_list).float() * 20000 - 10000 # comes out as batch_size x max_len x vocab_size
        # need batch_size x vocab_size x max_len
        loss = torch.nn.CrossEntropyLoss(reduction='none')(torch.permute(ideal_output, (0, 2, 1)), idx_list) * mask
        print(loss)
        print(ideal_output)
        return torch.sum(loss) / torch.sum(mask)

compute_loss(recognizer_full, dataset[N][0].to(device).float().unsqueeze(0), ['rrrbbr'])

tensor([[-0., -0., -0., -0., -0., -0., -0., -0.]], device='cuda:0')
tensor([[[-10000., -10000., -10000., -10000., -10000.,  10000., -10000.],
         [-10000., -10000., -10000., -10000., -10000.,  10000., -10000.],
         [-10000., -10000., -10000., -10000., -10000.,  10000., -10000.],
         [-10000., -10000., -10000., -10000., -10000., -10000.,  10000.],
         [-10000., -10000., -10000., -10000., -10000., -10000.,  10000.],
         [-10000., -10000., -10000., -10000., -10000.,  10000., -10000.],
         [ 10000., -10000., -10000., -10000., -10000., -10000., -10000.],
         [ 10000., -10000., -10000., -10000., -10000., -10000., -10000.]]],
       device='cuda:0')


tensor(0., device='cuda:0')

In [None]:
for i, img in enumerate(dataset.images):
    dataset.images[i] = T.CenterCrop(224)(img)

In [None]:
transform = T.Compose([T.Pad(50), T.RandAugment(num_ops=2), T.Resize(224), T.CenterCrop(224)])

In [None]:
import matplotlib.pyplot as plt

In [None]:
plt.imshow(transform(dataset[1571][0]).numpy().transpose((1, 2, 0)))