# Main Training Book for Our PCFG

In [1]:
# Import dependencies
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

import pcfg_models as pcfg_models
import pcfg_loader as loader
import source as source


scaler = GradScaler()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")

Device: cuda


  scaler = GradScaler()


In [2]:
# Import our data yielder
data_folder = 'training'
loader_params = {
    'batch_size': 1,
    'shuffle': True,
    'only_inputs': True,
    'print_steps': False,
    'moves_per_step': 1,
    'max_steps': 1,
    'p_use_base': 0.1,
}
data_yielder = loader.get_pcfg_datayielder(data_folder, loader_params)

In [3]:
# Instantiate our Encoder classifier
dsl_list = data_yielder.label_ops
embedding_model_fp = 'vit_11-21-24_100k_vF.pth'
em = source.embedding.load_ViT(embedding_model_fp, device='cpu')

pcfg_encoder = pcfg_models.PCFG_Encoder(
    n_embd = em.embed_dim, 
    n_head = 4, 
    n_layer = 6, 
    ff_hd = int(em.embed_dim * 2), 
    dropout = 0, # This may cause issues if you make this non-zero
    block_size = 35, 
    dsl_mapping = dsl_list, 
    embedding_model_fp = embedding_model_fp,
    freeze_emb_model=True, 
    device = device
)
pcfg_encoder = pcfg_encoder.to(device)
em = None   # Clear the model from memory

# Use this to load in a pre-trained model instead
# filename = ?
# pcfg_encoder = pcfg_models.PCFG_Encoder.load_model(f"trained_pcfg_models/{filename}", print_statements=True, device=device)

  checkpoint = torch.load(path, map_location=torch.device(device))


Vision Transformer instantiated with 84,064 parameters using Sinusoidal encodings.
Vision Transformer instantiated with 84,064 parameters using Sinusoidal encodings.
PCFG encoder instantiated with 52,288 parameters.


In [4]:
def embed_input(input, pcfg_encoder, use_grads, device):
    x = torch.zeros((len(input), pcfg_encoder.model_params['n_embd']), device=device)
    special_tokens = pcfg_encoder.get_special_tokens(device)   # cls_dsl, cls_obj, pad, sep
    
    try:
        first_pad = input[:16].index('<PAD>')
    except ValueError:
        first_pad = 16
    
    for i, obj in enumerate(input):
        if obj == "<PAD>":
            x[i, :] = special_tokens[2, :]
        elif obj == "<SEP>":
            x[i, :] = special_tokens[3, :]
        elif isinstance(obj, source.ARC_Object):
            # x[i, :] = torch.zeros((1, 64), device=device)
            obj.set_embedding(pcfg_encoder.embedding_model, use_grads=use_grads)
            x[i, :] = obj.embedding.to(device)
        else:
            raise NameError("Input contains object that is not '<PAD>', '<SEP>', or an ARC_Object.")
    x = torch.cat((special_tokens[:2, :], x), dim=0)
    return x, first_pad

In [5]:
def compute_kl_loss(dsl_cls, obj_att, dsl_label, obj_label, temp=0.2):
    """
    Compute multi-label multi-class loss using KL divergence.

    Args:
        dsl_cls: Tensor of size (len(dsl)) - predicted logits for DSL classes.
        obj_att: Tensor of size (len(obj_att)) - predicted logits for object attention.
        dsl_label: Tensor of integers up to len(dsl), representing ground truth labels for DSL classes.
        obj_label: Tensor of integers up to len(obj_att), representing ground truth labels for object attention.

    Returns:
        Total loss as a scalar tensor.
    """
    # One-hot encode labels for multi-label setup
    dsl_target = torch.zeros_like(dsl_cls, dtype=torch.float)
    dsl_target[dsl_label] = 1.0
    dsl_target = dsl_target / dsl_target.sum(dim=-1, keepdim=True)  # Normalize to probabilities
    obj_target = torch.zeros_like(obj_att, dtype=torch.float)
    obj_target[obj_label] = 1.0
    obj_target = obj_target / obj_target.sum(dim=-1, keepdim=True)  # Normalize to probabilities

    # Convert predictions to log-probabilities
    dsl_log_prob = F.log_softmax(dsl_cls, dim=-1)
    obj_log_prob = F.log_softmax(obj_att, dim=-1)

    # Compute KL divergence loss
    dsl_loss = F.kl_div(dsl_log_prob, dsl_target, reduction='batchmean')  # log_target=False by default
    obj_loss = F.kl_div(obj_log_prob, obj_target, reduction='batchmean')

    return dsl_loss, obj_loss

In [6]:
# Loop through data
train_params = {
    "epochs": 1,
    "lr": 1e-4,
    "print_frequency": 100,
}

optim = torch.optim.AdamW(pcfg_encoder.parameters(), lr=train_params['lr'])
avg_losses = []
seen_labels = [0] * len(dsl_list)
dsl_cls = None
obj_att = None

for epoch in range(train_params['epochs']):
    active = True
    step, dsl_loss_sum, obj_loss_sum, total_loss_sum = 0, 0, 0, 0

    while active:
        try:
            key, input, label, obj_idx = next(data_yielder)
        except StopIteration:
            active = False
            break
        except loader.SampleExhausted:
            continue

        seen_labels[label[0]] += 1
        with torch.amp.autocast('cuda'):
            x, first_pad = embed_input(input, pcfg_encoder, use_grads=not pcfg_encoder.freeze_emb_model, device=device)
            dsl_cls, obj_att = pcfg_encoder(x)
            obj_att = obj_att[2:first_pad + 2]  # Since we have 2 special cls tokens at the start

            # Compute individual losses
            dsl_loss, obj_loss = compute_kl_loss(
                dsl_cls, 
                obj_att, 
                torch.tensor(label, device=device), 
                torch.tensor(obj_idx, device=device)
            )

            optim.zero_grad(set_to_none=True)
            # total_loss = dsl_loss + obj_loss
            total_loss = dsl_loss
            scaler.scale(total_loss).backward()
            scaler.step(optim)
            scaler.update()

        # Accumulate losses
        total_loss_sum += total_loss.item()
        dsl_loss_sum += dsl_loss.item()
        obj_loss_sum += obj_loss.item()
        step += 1

        if step % train_params['print_frequency'] == 0:
            avg_total_loss = total_loss_sum / train_params['print_frequency']
            avg_dsl_loss = dsl_loss_sum / train_params['print_frequency']
            avg_obj_loss = obj_loss_sum / train_params['print_frequency']

            avg_losses.append((avg_total_loss, avg_dsl_loss, avg_obj_loss))

            print(f"[{(epoch+1):>2}/{train_params['epochs']:>2}] - {step:>4}: Total Loss = {avg_total_loss:.4f}, DSL Loss = {avg_dsl_loss:.4f}, Obj Loss = {avg_obj_loss:.4f}")

            total_loss_sum = 0
            dsl_loss_sum = 0
            obj_loss_sum = 0

[ 1/ 1] -  100: Total Loss = 0.2404, DSL Loss = 0.2404, Obj Loss = 0.3253
[ 1/ 1] -  200: Total Loss = 0.2418, DSL Loss = 0.2418, Obj Loss = 0.3256
[ 1/ 1] -  300: Total Loss = 0.2414, DSL Loss = 0.2414, Obj Loss = 0.3279
[ 1/ 1] -  400: Total Loss = 0.2389, DSL Loss = 0.2389, Obj Loss = 0.3276
[ 1/ 1] -  500: Total Loss = 0.2400, DSL Loss = 0.2400, Obj Loss = 0.3197
[ 1/ 1] -  600: Total Loss = 0.2412, DSL Loss = 0.2412, Obj Loss = 0.3248
[ 1/ 1] -  700: Total Loss = 0.2435, DSL Loss = 0.2435, Obj Loss = 0.3263
[ 1/ 1] -  800: Total Loss = 0.2338, DSL Loss = 0.2338, Obj Loss = 0.3264
[ 1/ 1] -  900: Total Loss = 0.2373, DSL Loss = 0.2373, Obj Loss = 0.3316
[ 1/ 1] - 1000: Total Loss = 0.2394, DSL Loss = 0.2394, Obj Loss = 0.3227
[ 1/ 1] - 1100: Total Loss = 0.2342, DSL Loss = 0.2342, Obj Loss = 0.3243
[ 1/ 1] - 1200: Total Loss = 0.2402, DSL Loss = 0.2402, Obj Loss = 0.3179
[ 1/ 1] - 1300: Total Loss = 0.2350, DSL Loss = 0.2350, Obj Loss = 0.3268
[ 1/ 1] - 1400: Total Loss = 0.2451, D

In [8]:
print(dsl_cls)
print(F.softmax(dsl_cls, dim=-1))
print(obj_att)
pcfg_encoder.save_model()

tensor([-1.0925e-01, -1.7139e-01, -3.2324e-01,  6.4754e-04, -6.7822e-01,
        -2.5284e-02, -9.2712e-02,  2.6367e-01,  5.4785e-01], device='cuda:0',
       dtype=torch.float16, grad_fn=<SqueezeBackward4>)
tensor([0.1008, 0.0948, 0.0814, 0.1125, 0.0571, 0.1097, 0.1025, 0.1465, 0.1946],
       device='cuda:0', dtype=torch.float16, grad_fn=<SoftmaxBackward0>)
tensor([0.0370, 0.1039, 0.1280, 0.1023], device='cuda:0',
       grad_fn=<SliceBackward0>)
Model and embedding model saved to trained_pcfg_models\pcfg_encoder_20241203_165657.pth
