In [1]:
!pip install -q wandb torchmetrics

[K     |████████████████████████████████| 1.8 MB 7.5 MB/s 
[K     |████████████████████████████████| 398 kB 52.9 MB/s 
[K     |████████████████████████████████| 181 kB 50.9 MB/s 
[K     |████████████████████████████████| 144 kB 46.2 MB/s 
[K     |████████████████████████████████| 63 kB 1.7 MB/s 
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


In [2]:
import pickle
import random
import gzip
import pytz
from datetime import datetime
import numpy as np
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

## Preprocessing Embeddings

In [3]:
compound_embedding_path = "/content/drive/MyDrive/LangOn/semanticsNTU/tencent_compounds_embeddings.txt.gz"

vocabs = []
embs = []

with gzip.open(compound_embedding_path, "rt") as fin:
  n_vocab, n_hdim = fin.readline().split(" ")
  n_hdim = int(n_hdim)
  for i in tqdm(range(int(n_vocab))):
    toks = fin.readline().strip().split(" ")
    word = toks[0]
    emb = np.array([float(x) for x in toks[1:]])
    vocabs.append(word)
    embs.append(emb)
embs = embs / np.linalg.norm(embs, axis=1)[:, np.newaxis]

  0%|          | 0/119797 [00:00<?, ?it/s]

In [4]:
import wandb
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [8]:
sum(1 for x in vocabs if len(x)==2), sum(1 for x in vocabs if len(x)==4)

(43394, 76398)

In [9]:
class ConvmorphNNDataset(Dataset):
  def __init__(self, vocabs, embs):
    self.build_dataset(vocabs, embs)

  def build_dataset(self, vocabs, embs):
    vocab_map = {vocab: idx for idx, vocab in enumerate(vocabs)}
    self.data = []
    for idx in tqdm(range(len(vocabs)), desc="building dataset"):
      word = vocabs[idx]
      const1 = word[:2]
      const2 = word[2:]
      if len(word) < 4: continue
      if not (const1 in vocab_map and const2 in vocab_map):
        continue
      emb = embs[idx]
      const1_vec = embs[vocab_map[const2]]
      const2_vec = embs[vocab_map[const1]]      
      self.data.append(dict(
          word=word,
          const1=const1, const2=const2,
          word_vec=emb,
          const1_vec=const1_vec, const2_vec=const2_vec
      ))
  
  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    return self.data[idx]

In [10]:
cm_ds = ConvmorphNNDataset(vocabs, embs)
len(cm_ds)

building dataset:   0%|          | 0/119797 [00:00<?, ?it/s]

75957

In [12]:
class ConvmorphDatasetArc(Dataset):
  def __init__(self, cm_dataset, idxs):
    self.build_dataset(cm_dataset, idxs)

  def build_dataset(self, ds, idxs):
    self.data = []
    for serial, idx in enumerate(idxs):
      data_x = ds[idx]
      inputX = np.concatenate([
          data_x["const1_vec"], data_x["const2_vec"]
      ]).reshape(1, 20, 20)
      target = data_x["word_vec"]
      self.data.append(dict(
          word_id=serial,
          inputX=torch.tensor(inputX, dtype=torch.float32),
          target=torch.tensor(target, dtype=torch.float32)
      ))
  
  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    return self.data[idx]


In [15]:
N = len(cm_ds)
rng = np.random.RandomState(123)
random_split = np.arange(N)
rng.shuffle(random_split)
train_idxs = random_split[:-6042]
test_idxs = random_split[-6042:]
train_ds = ConvmorphDatasetArc(cm_ds, train_idxs)
test_ds = ConvmorphDatasetArc(cm_ds, test_idxs)
train_eval_ds = ConvmorphDatasetArc(cm_ds, train_idxs[-len(test_ds):])
print(len(train_idxs), len(test_idxs), len(train_eval_ds))

69915 6042 6042


In [16]:
with open("drive/MyDrive/LangOn/convmorph/split_inidices_char4.pkl", "wb") as fout:
  pickle.dump({"train_idxs": train_idxs, "test_idxs": test_idxs}, fout)

## Prepare Dataset

In [17]:
from dataclasses import dataclass
import torch.nn.functional as F

@dataclass
class ConvmorphArcModelOutput:
  loss: torch.tensor
  pred_vec: torch.tensor

class ConvmorphArcModel(nn.Module):
  def __init__(self, dropout=0.1):
    super().__init__()

    self.conv1 = nn.Conv2d(1, 128, 3)
    self.norm1 = nn.LayerNorm([128, 18, 18])
    self.conv2 = nn.Conv2d(128, 128, 3)
    self.norm2 = nn.LayerNorm([128, 16, 16])    
    self.conv3 = nn.Conv2d(128, 128, 3)
    self.norm3 = nn.LayerNorm([128, 14, 14])    
    self.conv4 = nn.Conv2d(128, 128, 3)
    self.norm4 = nn.LayerNorm([128, 12, 12])
    self.conv5 = nn.Conv2d(128, 128, 3)
    self.norm5 = nn.LayerNorm([128, 10, 10])
    
    self.pool1 = nn.MaxPool2d(2)
    self.fn1 = nn.Linear(128*5*5, 3000)
    self.drop1 = nn.Dropout(p=dropout)
    # self.fn2 = nn.Linear(512, 200)   
    # self.fn2 = nn.Linear(2048, 1024)
    # self.drop2 = nn.Dropout(p=dropout)
    # self.fn3 = nn.Linear(1024, 512)
    # self.drop3 = nn.Dropout(p=dropout)
    # self.fn4 = nn.Linear(512, 200)
    self.fn4 = nn.Linear(3000, 200)
  
  def forward(self, inputX, target=None, **kwargs):
    
    z = F.relu(self.norm1(self.conv1(inputX)), inplace=True)
    z = F.relu(self.norm2(self.conv2(z)), inplace=True)
    z = F.relu(self.norm3(self.conv3(z)), inplace=True)
    z = F.relu(self.norm4(self.conv4(z)), inplace=True)    
    z = F.relu(self.norm5(self.conv5(z)), inplace=True)    
    z = self.pool1(z).view(-1, 128*5*5)
    o = self.drop1(torch.tanh(self.fn1(z)))    
    # o = self.drop2(torch.tanh(self.fn2(o)))
    # o = self.drop3(torch.tanh(self.fn3(o)))
    o = self.fn4(o)
    pred_vec = o

    if target is not None:
      loss_fct = nn.MSELoss()
      loss = loss_fct(pred_vec, target)
      return ConvmorphArcModelOutput(loss, pred_vec)
    else:      
      return ConvmorphArcModelOutput(float('nan'), pred_vec)
    

In [18]:
def compute_accuracy(data_loader, test_embs):
  n_correct = 0
  n_items = 0
  model.eval()
  for batch_x in tqdm(data_loader):
      with torch.no_grad():
        batch_x = {k: v.to("cuda") for k, v in batch_x.items()}
        word_ids = batch_x["word_id"].cpu().numpy()
        pred_vec = model(**batch_x).pred_vec      
        preds = torch.argmax(torch.matmul(pred_vec, test_embs.transpose(1, 0)), dim=1).cpu().numpy()
        n_correct += np.array(word_ids==preds, dtype=np.int32).sum()
        n_items += len(preds)
  return n_correct/n_items

In [19]:
from torch.optim.lr_scheduler import LambdaLR
# From https://github.com/huggingface/transformers/blob/v4.17.0/src/transformers/optimization.py
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):    
    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)

In [20]:
eval_embs = torch.tensor(np.vstack([x["target"] for x in test_ds]), dtype=torch.float32).to("cuda")
train_eval_embs = torch.tensor(np.vstack([x["target"] for x in train_eval_ds]), dtype=torch.float32).to("cuda")

In [21]:
config = dict(
    lr=5e-4,
    scheduler="linear",
    batch_size=128,
    epochs=100,
    dropout=0.1,
    note='128x5x5->fn:3000,200'
)

run_name = datetime.now(pytz.timezone('Asia/Taipei')).strftime("%m%d-%H%M-arc")
wandb.init(project="convmorph", name=run_name, config=config, save_code=True)

[34m[1mwandb[0m: Currently logged in as: [33mseantyh[0m (use `wandb login --relogin` to force relogin)


In [22]:
from itertools import islice
epochs = config["epochs"]
lr = config["lr"]
batch_size = config["batch_size"]

##   Dataset
## -----------
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

##    Model
## ----------
model = ConvmorphArcModel(dropout=config["dropout"]).to("cuda")
optimizer = optim.AdamW(model.parameters(), lr=lr)
total_steps = epochs*len(train_loader)
scheduler = get_linear_schedule_with_warmup(optimizer, 0, total_steps)

torch.backends.cudnn.deterministic = True
torch.manual_seed(123)
random.seed(123)

train_eploss_vec = []
eval_eploss_vec = []

iter_counter = 0
pbar = tqdm(total=total_steps)
for epoch_i in range(epochs):
  train_loss_vec = []
  eval_loss_vec = []
  model.train()  
  pbar.set_description(f"Epoch {epoch_i}/Train")
  for batch_x in train_loader:
    pbar.update(1)
    optimizer.zero_grad()  
    batch_x = {k: v.to("cuda") for k, v in batch_x.items()}
    loss = model(**batch_x).loss
    loss.backward()
    optimizer.step()
    scheduler.step()
    train_loss_vec.append(loss.item())    
    if iter_counter % 100 == 0:
      wandb.log({"train/loss": loss.item()}, step=iter_counter)
    iter_counter += 1
  train_eploss_vec.append(np.mean(train_loss_vec))  

  pbar.set_description(f"Epoch {epoch_i}/Eval")
  for batch_x in eval_loader:
    model.eval()
    with torch.no_grad():
      batch_x = {k: v.to("cuda") for k, v in batch_x.items()}
      loss = model(**batch_x).loss
      eval_loss_vec.append(loss.item())    
  eval_eploss_vec.append(np.mean(eval_loss_vec))  
  train_acc = compute_accuracy(DataLoader(train_eval_ds, batch_size=128, shuffle=False), train_eval_embs)
  test_acc = compute_accuracy(DataLoader(test_ds, batch_size=128, shuffle=False), eval_embs)
  wandb.log({
      "train/epoch-loss": train_eploss_vec[-1],
      "eval/epoch-loss": eval_eploss_vec[-1],
      "eval/train-acc": train_acc,
      "eval/test-acc": test_acc,
      "trainer/lr": scheduler.get_last_lr()[0]
  }, step=iter_counter)
  print(f"train/eval loss: {train_eploss_vec[-1]}/{eval_eploss_vec[-1]}")
wandb.finish()
pbar.close()

  0%|          | 0/54700 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.00973683074235889/0.004768426539764429


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.004439038016801016/0.004481772445918371


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.004253602815686403/0.0042298100791716324


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.004147347113052877/0.0038894032137856507


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0034784064325972222/0.003265069108844424


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0029940421086901174/0.002883894335051688


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0027346981538653756/0.0027253451892950884


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0025977224845377903/0.002599395336195206


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0025086851945484035/0.002563827598351054


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0024404874696333295/0.0024832583682533973


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0023814190431945773/0.0024211008955414095


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.002336707383724546/0.002369985231780447


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.002296057618708372/0.0023490952007705346


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0022633426292361846/0.00233504482700179


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0022343908027120927/0.002314019793023666


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0022112400167991135/0.00228458703107511


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.00219106261696755/0.00228637145482935


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.002173992961822197/0.0022730433411197737


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0021576466008176314/0.0022602275566896424


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0021425630270300853/0.0022479726176243275


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.002126093555439325/0.0022384157637134194


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0021140195833267116/0.002250532120039376


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0020997095684958467/0.0022433470100319632


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0020871657235003067/0.0022282245045062155


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.002075411499796309/0.0022376171400537714


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0020633428153229477/0.0022196357507103435


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.002051033004709746/0.0022180225253881267


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.002038498423351707/0.00221892939589452


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0020267215961338424/0.0022229269961826503


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0020170281175098345/0.0022192258281089985


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0020050946840168516/0.002221997817590212


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.001995331307462803/0.002202810709907984


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.001985550984498488/0.002211043693629714


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.001975259151977148/0.0022040719146995493


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0019653480090507367/0.0021977099725821367


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0019563225087738667/0.0021898415385900685


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.001947056781330318/0.0021966272761346772


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.001937735340961509/0.002193929123071333


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0019293860872863524/0.0021958719201696417


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.001922144298101188/0.0021941351102820286


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0019139890357685128/0.0021813629185392833


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0019057835543773373/0.0021908329314707467


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0018982658234222279/0.0021849139399516084


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.001891515925577697/0.0021930227182262265


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0018834378087205278/0.0021970492768256613


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0018769356888638882/0.0021843065430099764


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0018694912276266506/0.0021886965162896863


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0018627963812470137/0.002203017041514007


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0018560039535015004/0.0021913716821776084


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0018503621443629373/0.002193856290735615


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0018429757718173905/0.0021946769459949187


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0018373540880573978/0.0021901650179643184


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0018307988172638788/0.0021908980464407555


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0018244925922572858/0.0021928221152241654


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0018185241453492608/0.002198549084520588


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0018134102828875592/0.002201001455735726


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0018069821531617064/0.0022063608873092258


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.001801392460842932/0.0021913404537675283


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0017964956648579681/0.0021881425297275805


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0017906319159738569/0.0021940806036582217


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0017855388475871445/0.0022048309280459457


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.001781466146088265/0.0021862219097480797


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0017748788687429464/0.002188039376051165


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.001769704140222187/0.002200835534798292


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0017646310241872503/0.002198386049713008


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0017593018529711928/0.002193303487729281


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0017542773459729836/0.0021981905350306383


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0017499976275979084/0.0021953091394001


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0017446712518393503/0.002195521728329671


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0017405943757764163/0.002205352546297945


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0017355945781962962/0.002198641527987396


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0017308761600670582/0.002203662351045447


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0017270251871912418/0.0021964619857802368


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0017221423144362035/0.002209564942556123


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0017177300392335086/0.002198295507696457


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0017132875869234503/0.002198635925500033


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.001708842417901174/0.00220296701688009


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0017044382432468877/0.0022046209681623927


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0017005153110172511/0.002207469229081956


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0016958585091554977/0.002204391566920094


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0016913555985592456/0.0022052312997402623


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0016878128774937453/0.002207201352575794


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0016832709034915053/0.002216852493196105


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0016790809772969982/0.0022135562709687897


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0016748057595340083/0.002213504985168887


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0016711995874359605/0.002211565224570222


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0016669971701115287/0.0022077028455290324


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0016630064703309585/0.0022094734207106135


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0016592613038721015/0.002214137236781729


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.001655631010178851/0.0022135016915854067


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0016516707345955687/0.002216519981933137


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0016476505859950778/0.0022099230894430852


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0016434098590911116/0.0022178652143338695


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0016402367536928607/0.002215188287664205


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0016359240989331871/0.0022177820113332323


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0016323556164551997/0.002217330703084978


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0016285603641663462/0.002214365037313352


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0016246282602396455/0.0022174721428503594


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0016207365998896128/0.0022168905998114496


  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

train/eval loss: 0.0016172144998723072/0.0022164896169366934



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
eval/epoch-loss,█▇▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eval/test-acc,▁▁▅▆▇▇▇█████████████████████████████████
eval/train-acc,▁▁▄▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇█████████████████████
train/epoch-loss,█▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▇▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▂▂▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁
trainer/lr,████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁

0,1
eval/epoch-loss,0.00222
eval/test-acc,0.738
eval/train-acc,0.93959
train/epoch-loss,0.00162
train/loss,0.00163
trainer/lr,0.0


In [24]:
import os
out_dir = "drive/MyDrive/LangOn/convmorph/" + run_name
os.makedirs(out_dir, exist_ok=True)
torch.save(model.state_dict(), out_dir+"/model.pth")