In [1]:
from transformers import BertTokenizer, BertModel




In [2]:
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')


In [2]:
en = []
with open('/kaggle/input/machinetranslationenvi/train.en', 'r', encoding='utf-8') as file:
	for line in file:
		en.append(line.strip())  # strip() removes trailing newline characters

vi = []
with open('/kaggle/input/machinetranslationenvi/train.vi', 'r', encoding='utf-8') as file:
	for line in file:
		vi.append(line.strip())  # strip() removes trailing newline characters

In [3]:
en_valid = []
with open('/kaggle/input/machinetranslationenvi/tst2012.en', 'r', encoding='utf-8') as file:
	for line in file:
		en_valid.append(line.strip())  # strip() removes trailing newline characters

vi_valid = []
with open('/kaggle/input/machinetranslationenvi/tst2012.vi', 'r', encoding='utf-8') as file:
	for line in file:
		vi_valid.append(line.strip())  # strip() removes trailing newline characters

In [4]:
print(en[:5] )
print(vi[:5] )
print(en_valid[:5] )
print(vi_valid[:5] )

['Rachel Pike : The science behind a climate headline', 'In 4 minutes , atmospheric chemist Rachel Pike provides a glimpse of the massive scientific effort behind the bold headlines on climate change , with her team -- one of thousands who contributed -- taking a risky flight over the rainforest in pursuit of data on a key molecule .', 'I &apos;d like to talk to you today about the scale of the scientific effort that goes into making the headlines you see in the paper .', 'Headlines that look like this when they have to do with climate change , and headlines that look like this when they have to do with air quality or smog .', 'They are both two branches of the same field of atmospheric science .']
['Khoa học đằng sau một tiêu đề về khí hậu', 'Trong 4 phút , chuyên gia hoá học khí quyển Rachel Pike giới thiệu sơ lược về những nỗ lực khoa học miệt mài đằng sau những tiêu đề táo bạo về biến đổi khí hậu , cùng với đoàn nghiên cứu của mình -- hàng ngàn người đã cống hiến cho dự án này -- m

In [5]:
train_data_src = en[2269:(2269+4096)]
train_data_trg= vi[2269:(2269+4096)]
valid_data_src = en_valid[269:(269+512)]
valid_data_trg= vi_valid[269:(269+512)]
test_data_src = en_valid[4:(4+256)]
test_data_trg= vi_valid[4:(4+256)]



In [6]:
# from data import SentenceDataset
# from model import TransformerMT
from transformers import BertTokenizerFast
# from w2v import WordEmbedding
import torch.optim as optim
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import time

In [7]:
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from transformers import BertTokenizerFast



class SentenceDataset(Dataset):
    def __init__(self, src_sentence, tgt_sentence, tokenizer, max_length):
        self.src = src_sentence 
        self.tgt = tgt_sentence
        self.tokenizer = tokenizer
        self.max_length = max_length 

    def get_tokenized_sentences(self, sentence):
        tokenized_sentence = self.tokenizer(sentence, padding='max_length', truncation=True, return_tensors="pt", max_length=self.max_length)
        return tokenized_sentence['input_ids'], tokenized_sentence['attention_mask']

    def __len__(self):
        return len(self.src)
    
    def __getitem__(self, idx):
        tokenized_src, src_mask = self.get_tokenized_sentences(self.src[idx])
        tokenized_tgt, tgt_mask = self.get_tokenized_sentences(self.tgt[idx])
        return {
            'src': tokenized_src.squeeze(0),
            'src_mask': src_mask.squeeze(0),
            'tgt': tokenized_tgt.squeeze(0),
            'tgt_mask': tgt_mask.squeeze(0)
        }

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# postiional embedding in w2v
# self-attention
# multi-head attention ? the difference between self-attention and multi-head attention in implementation
# can we merge self-attention and multi-head attention into one class?
# feed forward network
# layer normalization
class LayerNorm(nn.Module):
	def __init__(self, ndim, eps: float = 1e-5):
		super(LayerNorm, self).__init__()
		self.gamma = nn.Parameter(torch.ones(ndim))
		self.beta = nn.Parameter(torch.zeros(ndim))
		self.eps = eps
	
	def forward(self, x):
		mean = x.mean(-1, keepdim=True)
		std = x.std(-1, keepdim=True)
		return self.gamma * (x - mean) / (std + self.eps) + self.beta

class SelfAttention(nn.Module):
	def __init__(self, embed_size, nhead, dropout):
		super(SelfAttention, self).__init__()
		self.dropout = nn.Dropout(dropout)
		self.p_qkv = nn.Linear(embed_size, embed_size*3)
		torch.nn.init.xavier_uniform_(self.p_qkv.weight)
		self.p_proj = nn.Linear(embed_size, embed_size)
		torch.nn.init.xavier_uniform_(self.p_proj.weight)
		self.nhead = nhead
	
	def forward(self, x , attmask):   
		'''
		1. input q, k, v, attention mask x [bs, nhead, seq_len, embed_size], attmask [bs, seqlen]
		2. calculate the attention score
		3. add & norm ( dropout residual connection before add )
		4. feed forward network
		5. add & norm ( dropout residual connection before add )
		ensure that output have shape [bs, seqlen, embed_size*n_head]

		'''

		x = self.p_qkv(x) # [bs, seq_len, embed_size*3]
		q, k, v = torch.chunk(x, 3, dim = -1) # q, k, v [bs, seq_len, embed_size]
		bs, sqlen, embed_size = q.size()

		q = q.view(bs, sqlen, self.nhead, embed_size//self.nhead).transpose(1, 2)
		k = k.view(bs, sqlen, self.nhead, embed_size//self.nhead).transpose(1, 2)
		v = v.view(bs, sqlen, self.nhead, embed_size//self.nhead).transpose(1, 2)

		# configure mask
		# because att mask can be 2D if it is used in encoder, and 3D if it is used in decoder
		if attmask.dim() == 2:
			attmask = attmask.unsqueeze(1).unsqueeze(2) # [bs, 1, 1, seq_len]
			# attmask [bs, seqlen, embed_size], we need to translate  [bs, nhead, seqlen, embed_size//nhead]
		
		

		att_score = (q @ k.transpose(-2, -1)) * (1.0/math.sqrt(q.size(-1))) # [bs, nhead, seq_len, seq_len]
		att_score = att_score.masked_fill(attmask == 0, -10000)
		att_score = F.softmax(att_score, dim = -1)
		att_score = self.dropout(att_score) # [bs, nhead, seq_len, seq_len]
		y = att_score @ v # [bs, nhead, seqlen, embed_size//nhead]
		y = y.transpose(1, 2).contiguous().view(bs, sqlen, embed_size)
		
		# is y need to be go through a linear layer?
		y = self.p_proj(y)
		return y, attmask

class CrossAttention(nn.Module):
	def __init__(self,  embed_size, nhead):
			super(CrossAttention, self).__init__()
			self.nhead = nhead
			self.W_q = nn.Linear(embed_size, embed_size)
			torch.nn.init.xavier_uniform_(self.W_q.weight)

			self.W_k = nn.Linear(embed_size, embed_size)
			torch.nn.init.xavier_uniform_(self.W_k.weight)
			
			self.W_v = nn.Linear(embed_size, embed_size)
			torch.nn.init.xavier_uniform_(self.W_v.weight)

			self.W_o = nn.Linear(embed_size, embed_size)
			torch.nn.init.xavier_uniform_(self.W_o.weight)

			self.dropout = nn.Dropout(0.1)
			
	def forward(self, tgt, enc, mask):
		# tgt [bs, seqlen, embed_size]
		# enc [bs, seqlen, embed_size]
		# mask [bs, seqlen]
		bs, tgt_size, embed_size = tgt.size()
		src_size = enc.size(1)

		q = self.W_q(tgt)
		k = self.W_k(enc)
		v = self.W_v(enc)

		# [bs, seqlen, embed_size] -> [bs, seqlen, nhead, embed_size//nhead] -> [bs, nhead, seqlen, embed_size//nhead
		q = q.view(bs, tgt_size, self.nhead, embed_size//self.nhead).transpose(1, 2)
		k = k.view(bs, src_size, self.nhead, embed_size//self.nhead).transpose(1, 2)
		v = v.view(bs, src_size, self.nhead, embed_size//self.nhead).transpose(1, 2)

		att_score = (q @ k.transpose(-2, -1)) * (1.0/math.sqrt(q.size(-1)))
		att_score = att_score.masked_fill(mask == 0, -10000)
		att_score = F.softmax(att_score, dim = -1) # [bs, seqlen, seqlen]
		y = att_score @ v
		y = y.transpose(1, 2).contiguous().view(bs, tgt_size, -1)
		y = self.W_o(y)
		return self.dropout(y)

class FFN(nn.Module):
	def __init__(self, embed_size):
			super().__init__()
			self.linear1  = nn.Linear(embed_size, embed_size*4)
			torch.nn.init.xavier_uniform_(self.linear1.weight)
			self.linear2 = nn.Linear(embed_size*4, embed_size)
			torch.nn.init.xavier_uniform_(self.linear2.weight)
			self.gelu = nn.GELU()
			self.dropout = nn.Dropout(0.1)

		
	def forward(self, x):
			x = self.linear1(x)
			x = self.gelu(x)
			x = self.dropout(x)
			x = self.linear2(x)
			x = self.dropout(x)
			return x

class EncoderLayer(nn.Module):
	def __init__(self, embed_size, nhead, dropout,  bias=True, eps=1e-06):
		super().__init__()    
		self.selfattn = SelfAttention(embed_size, nhead, dropout)
		self.ffn = FFN(embed_size)
		self.dropout = nn.Dropout(0.1)
		self.norm = LayerNorm(embed_size )

	def forward(self, x, mask):
		# x  [bs, seqlen, embed_size]
		_x = x
		x, mask = self.selfattn(x, mask)
		x = _x +  self.dropout(x)
		x = self.norm(x)

		x = x + self.dropout(self.ffn(x))
		x = self.norm(x)
		return x, mask

class DecoderLayer(nn.Module):
	def __init__(self, config):
		super(DecoderLayer, self).__init__()
		# we use masked self-attention in decoder, we just need to pass the exact mask to the self-attention
		self.maskedselfattn = SelfAttention(config['embed_size'], config['nhead'], config['dropout'])
		self.norm1 = LayerNorm(config['embed_size'])

		self.crossatt = CrossAttention(config['embed_size'], config['nhead'])
		self.norm2 = LayerNorm(config['embed_size'])    

		self.ffn = FFN(config['embed_size'])
		self.norm3 = LayerNorm(config['embed_size'])

		
		self.dropout = nn.Dropout(0.1)

	def _init_casual_mask(self, mask):
		# mask [bs, seqlen]
		mask = mask.unsqueeze(-1).expand(-1, -1, mask.size(-1)).unsqueeze(1) # [bs, 1, seqlen, seqlen]
		mask = torch.tril(mask, diagonal=0)
		return mask # [bs, 1, seqlen, seqlen]
	
	def forward(self, tgt, encoder_output, tgt_mask, src_mask):
		# tgt [bs, seqlen, embed_size]
		# encoder_output [bs, seqlen, embed_size]
		# tgt_mask [bs, seqlen]
		# src_mask [bs, seqlen]
		# tgt_mask = self._init_casual_mask(tgt_mask)
		_tgt = tgt
		tgt, tgt_mask = self.maskedselfattn(tgt, tgt_mask)
		tgt = _tgt + self.dropout(tgt)
		# tgt = tgt + self.dropout(self.maskedselfattn(tgt, tgt_mask))
		tgt = self.norm1(tgt)
		
		# src_mask = src_mask.unsqueeze(-1).expand(-1, -1, tgt_mask.size(-1)).unsqueeze(1) # [bs, 1, seqlen, seqlen ]
		tgt = tgt + self.dropout(self.crossatt(tgt, encoder_output, src_mask))
		tgt = self.norm2(tgt)
		
		tgt = tgt + self.dropout(self.ffn(tgt))
		tgt = self.norm3(tgt)
		
		return tgt, tgt_mask
	
			


class TransformerMT(nn.Module):
	def __init__(self, args):
		super(TransformerMT, self).__init__()
		self.args = args
		self.encoder = EncoderLayer(args['embed_size'], args['nhead'], args['dropout'])
		self.decoder = DecoderLayer(args)
	
	def forward(self, src, tgt, src_mask, tgt_mask):
		# src [bs, seqlen, embed_size]
		# tgt [bs, seqlen, embed_size]
		# src_mask [bs, seqlen]
		# tgt_mask [bs, seqlen]
		enc_output, src_mask = self.encoder(src, src_mask)
		dec_output, tgt_mask = self.decoder(tgt, enc_output, tgt_mask, src_mask)
		return enc_output, dec_output, src_mask, tgt_mask
	

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers import  BertModel

class Word2Vec(nn.Module):
	def __init__(self, vocab_size, embed_size, BERT = False): 
		super(Word2Vec, self).__init__()
		if BERT:
			model = BertModel.from_pretrained('bert-base-multilingual-cased')
			self.embeddings = model.embeddings.word_embeddings
			self.embeddings.requires_grad_(False)
		else:	
			self.embeddings = nn.Embedding(vocab_size, embed_size)
			torch.nn.init.normal_(self.embeddings.weight, mean=0, std=0.02)
	def forward(self, x):
		x = self.embeddings(x)
		return x
	


class PositionalEmbedding(nn.Module):
	def __init__(self,embed_size, max_len, device):
		super(PositionalEmbedding, self).__init__()
		self.encoding = torch.zeros(max_len, embed_size, device=device, requires_grad=False)
		pos = torch.arange(0, max_len).float().	unsqueeze(1)
		_2i = torch.arange(0, embed_size, 2).float()
		self.encoding[:, 0::2] = torch.sin(pos/ torch.pow(10000, _2i/ embed_size))
		self.encoding[:, 1::2] = torch.cos(pos/ torch.pow(10000, _2i/ embed_size))

	def forward(self, x):
		# bs, seqlen, embed_dim = x.size()
		# pe_tensor = torch.zeros(seqlen, embed_dim)
		# sin = [torch.sin(pos/ torch.pow(10000, torch.arange(0, embed_dim, 2)/ embed_dim)) for  pos in self.pos]
		# cos = [torch.cos(pos/ torch.pow(10000, torch.arange(1, embed_dim, 2)/ embed_dim)) for pos in self.pos]
		# pe_tensor[:, 0::2] = sin
		# pe_tensor[:, 1::2] = cos
		# pe_tensor = pe_tensor.unsqueeze(0).expand(bs, seqlen, embed_dim)
		bs, seqlen, embed_dim = x.size()
		return self.encoding[:seqlen, :].expand(bs, seqlen, embed_dim)

class WordEmbedding(nn.Module):
	def __init__(self, vocab_size, embed_size, max_len, device, BERT=False):
		super(WordEmbedding, self).__init__()
		self.word2vec = Word2Vec(vocab_size, embed_size, BERT)
		self.positional_embedding = PositionalEmbedding( embed_size, max_len, device)
		
	def forward(self, x):
		x = self.word2vec(x)
		x = x + self.positional_embedding(x)
		return x

In [10]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]



config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

In [11]:
train_data = SentenceDataset(en, vi, tokenizer, max_length=128)
valid_data = SentenceDataset(valid_data_src, valid_data_trg, tokenizer, max_length=128)
test_data = SentenceDataset(test_data_src, test_data_trg, tokenizer, max_length=128)

In [12]:
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=32, shuffle=False)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [14]:
tokenizer.vocab_size

119547

In [15]:
args = {
	'embed_size': 768,
	'num_layers': 8,
	'max_len' : 128,
	'nhead': 12,
	'dropout': 0.1,
	'vocab_size': tokenizer.vocab_size,
	'BERT': True,
	'device': device
}

In [16]:
class MultiLayerTransformerMT(nn.Module):
	def __init__(self, args):
		super(MultiLayerTransformerMT, self).__init__()
		self.embeddings = WordEmbedding(args['vocab_size'], args['embed_size'], args['max_len'], args['device'], args['BERT'])
		self.transformer = nn.ModuleList([TransformerMT(args) for _ in range(args['num_layers'])])
		self.head = nn.Linear(args['embed_size'], args['vocab_size'])
		torch.nn.init.xavier_uniform_(self.head.weight)

	def forward(self, src, tgt, src_mask, tgt_mask):
		src = self.embeddings(src)
		tgt = self.embeddings(tgt)
		for layer in self.transformer:
			src, tgt, src_mask, tgt_mask = layer(src, tgt, src_mask, tgt_mask)
		tgt = self.head(tgt)
		return tgt.reshape(-1, tgt.size(-1))



In [17]:
model = MultiLayerTransformerMT(args)
model.to(device)
# model = torch.compile(model)
optim = torch.optim.AdamW(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9,  weight_decay=0.001)	

model.safetensors:   0%|          | 0.00/714M [00:00<?, ?B/s]

In [18]:
critertion = nn.CrossEntropyLoss().to(device)

In [19]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [24]:
len(train_loader)*5

20835

In [25]:
def train (model, data, optimizer, critertion, device, epochs=1):
	model.train()
	start = time.time()
	running_loss = 0
	for j in range(epochs):
		for i, batch in enumerate(data):
			src = batch['src'].to(device)
			tgt = batch['tgt'].to(device)
			src_mask = batch['src_mask'].to(device)
			tgt_mask = batch['tgt_mask'].to(device)
			optimizer.zero_grad()
			with torch.autocast(device_type='cuda', dtype=torch.float16):
				output = model(src, tgt[:, :-1], src_mask, tgt_mask[:, :-1])
				# output = output.view(-1, output.size(-1))
				loss = critertion(output, tgt[:, 1:].contiguous().view(-1))
			loss.backward()
			torch.nn.utils.clip_grad_norm_(v_1, max_norm=1.0, norm_type=2)
			optimizer.step()
			torch.cuda.synchronize()
			running_loss += (loss.item())
			if (i+1) % 50 == 0:
				print(f'Epoch: {j}, step: {i}, Loss: {loss.item()/i}')
	end = time.time()
	print(f'Time taken: {end-start}')

In [26]:
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(total_params)

224236539


In [28]:
train(model, train_loader, optim, critertion, device, epochs=5)

Epoch: 0, step: 49, Loss: 0.03772320552748077
Epoch: 0, step: 99, Loss: 0.021058894167042742
Epoch: 0, step: 149, Loss: 0.013322587781304481
Epoch: 0, step: 199, Loss: 0.010958612863741928
Epoch: 0, step: 249, Loss: 0.009823525287061331
Epoch: 0, step: 299, Loss: 0.010607447512572426
Epoch: 0, step: 349, Loss: 0.005456005946271399
Epoch: 0, step: 399, Loss: 0.006107738442289501
Epoch: 0, step: 449, Loss: 0.004755276083149729
Epoch: 0, step: 499, Loss: 0.004024295864219895
Epoch: 0, step: 549, Loss: 0.003843783029441625
Epoch: 0, step: 599, Loss: 0.003705552105911586
Epoch: 0, step: 649, Loss: 0.0032274285891023
Epoch: 0, step: 699, Loss: 0.003130686641251069
Epoch: 0, step: 749, Loss: 0.0026788514192018394
Epoch: 0, step: 799, Loss: 0.0025470972956345883
Epoch: 0, step: 849, Loss: 0.0031042034409773504
Epoch: 0, step: 899, Loss: 0.0026117957075923117
Epoch: 0, step: 949, Loss: 0.001995181760998998
Epoch: 0, step: 999, Loss: 0.0027010927687178147
Epoch: 0, step: 1049, Loss: 0.0020815717

KeyboardInterrupt: 

In [29]:
import time
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

In [30]:
def evaluation(model, data, critertion, device):
	model.eval()
	running_loss = 0
	t0 = time.time()
	bleu_score = 0
	total_samples = 0
	for i, batch in enumerate(data):
			src = batch['src'].to(device)
			tgt = batch['tgt'].to(device)
			src_mask = batch['src_mask'].to(device)
			tgt_mask = batch['tgt_mask'].to(device)
			with torch.no_grad():
				output = model(src, tgt[:, :-1], src_mask, tgt_mask[:, :-1]) # output = [batch_size, tgt_len-1, vocab_size]
				loss = critertion(output, tgt[:, 1:].contiguous().view(-1))

				topk_prob, topk_ids = output.topk(k=3, dim=-1)
				id = torch.multinomial(topk_prob, num_samples=1)
				xcol = torch.gather(topk_ids, -1, id)
				generated_tokens = xcol.squeeze(-1).view(-1, src.size(1)-1)
				# Calculate BLEU for each sentence and accumulate
				for ref, pred in zip(tgt[:, :-1], generated_tokens):
					bleu_score += sentence_bleu([ref.cpu().numpy().tolist()], pred.cpu().numpy().tolist(), smoothing_function=SmoothingFunction().method4)
			total_samples += src.size(0)				
			running_loss += loss.item()
	t1 = time.time()
	print(f"Training time: {t1-t0}, Loss: {running_loss/len(data)}, BLEU: {bleu_score/total_samples}")

In [31]:
evaluation(model, valid_loader, critertion, device)

RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

In [None]:
evaluation(model, test_loader, critertion, device)

In [None]:
tokenizer.cls_token_id

In [None]:
def generate_translation(model, sentences, device, tokenizer, max_length):
	model.eval()
	src = tokenizer(sentences, padding="max_length", truncation=True, return_tensors='pt', max_length=max_length)
	src_mask = src['attention_mask'].to(device)
	src = src['input_ids'].to(device)

	sample_rng = torch.Generator(device=device)
	sample_rng.manual_seed(123)    
    
	tgt = torch.Tensor([[tokenizer.cls_token_id]]).long().to(device)
	for _ in range(max_length):
		tgt_mask = torch.ones(src.size(0), tgt.size(1), device=device).tril(diagonal=0)
		with torch.no_grad():
			output = model(src, tgt, src_mask, tgt_mask)
			next_token = output[-1, :]
			next_token = F.softmax(next_token, dim=-1)
			topk_prob, topk_idx = torch.topk(next_token, k=8, dim=-1)
			id  = torch.multinomial(topk_prob, num_samples=1, generator=sample_rng)
			actual_token = topk_idx.gather(dim=-1, index=id).unsqueeze(-1)
			tgt = torch.cat([tgt,actual_token], dim = 1)
	return tgt

In [None]:
tokenizer.decode(generate_translation(model, "Hello", device, tokenizer, 32).squeeze(0))

In [None]:
tokenizer.decode(generate_translation(model, "That , I think , is movement .", device, tokenizer, 32).squeeze(0))

In [None]:
tokenizer.decode(generate_translation(model, "Thank you very much .", device, tokenizer, 32).squeeze(0))

In [None]:
tokenizer.decode(generate_translation(model, "They &apos;re kept in bare cells like this for 23 hours a day .", device, tokenizer, 32).squeeze(0))

In [None]:
torch.save(model.state_dict(), "/kaggle/working/transf")

In [54]:
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(total_params)

15654779


In [2]:
import torch

In [3]:
a = torch.randn(3,4)
a[:, 1].shape

torch.Size([3])

In [4]:
a

tensor([[ 0.7775,  1.0295,  0.0726, -0.1331],
        [ 0.2277, -0.3164,  0.5758, -0.4550],
        [-1.9309,  1.7030, -0.3719, -0.1745]])

In [10]:
model = torch.nn.Embedding(10,3)

In [13]:
topk_prob, topk_idx = torch.topk(a, k=1, dim=-1)
print(topk_prob,"\n", topk_prob.shape,  "\n", model(topk_idx).shape, '\n', model(topk_idx.squeeze(-1)).shape)

tensor([[1.0295],

        [0.5758],

        [1.7030]]) 

 torch.Size([3, 1]) 

 torch.Size([3, 1, 3]) 

 torch.Size([3, 3])


In [None]:
torch.save(model.state_dict(), "/kaggle/working/transf")