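# Main Training Notebook for Our PCFG Encoder

Trains `PCFG_Encoder` on PCFG-generated training samples to classify the DSL operation that was applied (the object-attention loss is logged alongside), then saves the trained model.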

In [1]:
# Import dependencies
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

import pcfg_models as pcfg_models
import pcfg_loader as loader
import source as source


scaler = GradScaler()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")

Device: cuda


  scaler = GradScaler()


In [2]:
# Data-loader configuration. The option semantics (e.g. p_use_base, moves_per_step)
# are defined by the project-local pcfg_loader.get_pcfg_datayielder.
loader_params = dict(
    datafolder='training',
    batch_size=1,
    shuffle=True,
    only_inputs=True,
    print_steps=False,
    moves_per_step=1,
    max_steps=1,
    p_use_base=0.1,
)
data_yielder = loader.get_pcfg_datayielder(**loader_params)

In [3]:
# Build the PCFG encoder classifier.
# The ViT embedding model is loaded here (on CPU) only so its embedding width can be read;
# the encoder is handed embedding_model_fp and appears to load its own copy (the ViT
# instantiation message is printed twice in the output).
dsl_list = data_yielder.label_ops
embedding_model_fp = 'vit_12-3-24_100k_v2.pth'
em = source.embedding.load_v2_ViT(embedding_model_fp, device='cpu')

pcfg_encoder = pcfg_models.PCFG_Encoder(
    n_embd=em.embed_dim,
    n_head=4,
    n_layer=6,
    ff_hd=int(em.embed_dim * 2),
    dropout=0,                  # This may cause issues if you make this non-zero
    block_size=35,
    dsl_mapping=dsl_list,
    embedding_model_fp=embedding_model_fp,
    freeze_emb_model=True,
    device=device,
).to(device)
em = None  # Drop our reference so the temporary ViT can be freed

# To resume from a pre-trained model instead, set `filename` and uncomment:
# filename = ?
# pcfg_encoder = pcfg_models.PCFG_Encoder.load_model(f"trained_pcfg_models/{filename}", print_statements=True, device=device)

  checkpoint = torch.load(path, map_location=torch.device(device))


Vision Transformer instantiated with 117,472 parameters using Sinusoidal encodings.
Vision Transformer instantiated with 117,472 parameters using Sinusoidal encodings.
PCFG encoder instantiated with 52,288 parameters.


In [4]:
def embed_input(input, pcfg_encoder, use_grads, device, max_objects=16):
    """
    Embed one tokenised sample for the PCFG encoder.

    Each element of ``input`` becomes one row of the output. '<PAD>' and '<SEP>' use the
    corresponding rows of ``pcfg_encoder.get_special_tokens(device)``. ARC_Objects are
    embedded via ``obj.set_embedding(pcfg_encoder.embedding_model, use_grads=...)`` and
    their ``.embedding`` is copied in. The two classification tokens (cls_dsl, cls_obj)
    are then prepended.

    Args:
        input: Sequence whose items are '<PAD>', '<SEP>' or source.ARC_Object.
        pcfg_encoder: Encoder exposing model_params['n_embd'], get_special_tokens(device)
            (rows: cls_dsl, cls_obj, pad, sep) and embedding_model.
        use_grads: Forwarded to ARC_Object.set_embedding.
        device: Device on which the output tensor is allocated.
        max_objects: Length of the leading segment of ``input`` searched for the first
            '<PAD>'. Defaults to 16, the previously hard-coded value.

    Returns:
        (x, first_pad): x has shape (2 + len(input), n_embd) with the cls tokens first.
        first_pad is the index of the first '<PAD>' within input[:max_objects], or
        max_objects if none is present.

    Raises:
        NameError: if an item is not '<PAD>', '<SEP>' or an ARC_Object. (TypeError would
            be more apt, but NameError is kept for existing callers.)
    """
    x = torch.zeros((len(input), pcfg_encoder.model_params['n_embd']), device=device)
    special_tokens = pcfg_encoder.get_special_tokens(device)   # cls_dsl, cls_obj, pad, sep

    # Only the leading segment is searched for padding. This is presumably the input-grid
    # objects (block_size 35 = 2 cls + 16 + SEP + 16); TODO confirm against the loader.
    try:
        first_pad = input[:max_objects].index('<PAD>')
    except ValueError:
        first_pad = max_objects

    for i, obj in enumerate(input):
        if obj == "<PAD>":
            x[i, :] = special_tokens[2, :]
        elif obj == "<SEP>":
            x[i, :] = special_tokens[3, :]
        elif isinstance(obj, source.ARC_Object):
            obj.set_embedding(pcfg_encoder.embedding_model, use_grads=use_grads)
            x[i, :] = obj.embedding.to(device)
        else:
            raise NameError("Input contains object that is not '<PAD>', '<SEP>', or an ARC_Object.")

    # Prepend the two classification tokens (cls_dsl, cls_obj).
    x = torch.cat((special_tokens[:2, :], x), dim=0)
    return x, first_pad

In [5]:
def compute_kl_loss(dsl_cls, obj_att, dsl_label, obj_label, temp=0.2):
    """
    Compute multi-label multi-class loss using KL divergence.

    Each label set becomes a uniform target distribution over its labelled classes
    (e.g. labels [1, 3] -> probability 0.5 on classes 1 and 3). The loss is
    KL(target || softmax(logits)).

    Args:
        dsl_cls: Tensor of size (len(dsl)) - predicted logits for DSL classes.
        obj_att: Tensor of size (len(obj_att)) - predicted logits for object attention.
        dsl_label: Tensor of integer class indices into dsl_cls (ground-truth DSL labels).
        obj_label: Tensor of integer class indices into obj_att (ground-truth objects).
        temp: Currently unused; kept so existing callers keep working.

    Returns:
        Tuple (dsl_loss, obj_loss) of scalar tensors.

    Raises:
        ValueError: if a label set is empty (the target would be 0/0 = NaN).
    """
    def _kl_to_labels(logits, labels):
        # Multi-hot target, normalised to a probability distribution.
        target = torch.zeros_like(logits, dtype=torch.float)
        target[labels] = 1.0
        total = target.sum(dim=-1, keepdim=True)
        if bool((total == 0).any()):
            raise ValueError("compute_kl_loss: empty label set; cannot build a target distribution.")
        target = target / total

        log_prob = F.log_softmax(logits, dim=-1)
        # 'batchmean' divides by input.size(0). For a single unbatched (1-D) distribution
        # that is the number of classes, not 1, so add a batch dimension of 1 to get the
        # true KL divergence.
        if log_prob.dim() == 1:
            log_prob, target = log_prob.unsqueeze(0), target.unsqueeze(0)
        return F.kl_div(log_prob, target, reduction='batchmean')  # log_target=False by default

    return _kl_to_labels(dsl_cls, dsl_label), _kl_to_labels(obj_att, obj_label)

def compute_cross_entropy_loss(dsl_cls, obj_att, dsl_label, obj_label, max_loss=2.5):
    """
    Compute cross-entropy loss for DSL and object attention logits, with clipping.

    Args:
        dsl_cls: Tensor of size (len(dsl)) - predicted logits for DSL classes.
        obj_att: Tensor of size (len(obj_att)) - predicted logits for object attention.
        dsl_label: Single-element (shape (1,) or 0-dim) tensor holding the ground-truth
            DSL class index.
        obj_label: Single-element (shape (1,) or 0-dim) tensor holding the ground-truth
            object index.
        max_loss: Upper bound each loss is clamped to (default: 2.5). Pass None to
            disable clipping.

    Returns:
        Tuple (dsl_loss, obj_loss) of scalar tensors, clipped to max_loss if given.

    Note:
        torch.clamp has zero gradient wherever it clips. A sample whose loss exceeds
        max_loss therefore contributes no gradient at all; it is not merely down-weighted.
    """
    # cross_entropy expects (N, C) logits with (N,) targets. reshape(-1) turns a 0-dim
    # label into shape (1,) and leaves shape (1,) unchanged.
    dsl_loss = F.cross_entropy(dsl_cls.unsqueeze(0), dsl_label.reshape(-1))
    obj_loss = F.cross_entropy(obj_att.unsqueeze(0), obj_label.reshape(-1))

    if max_loss is None:
        return dsl_loss, obj_loss

    # Clip the loss to a maximum value
    return torch.clamp(dsl_loss, max=max_loss), torch.clamp(obj_loss, max=max_loss)

In [6]:
# Loop through data
train_params = {
    "epochs": 1,
    "lr": 1e-4,
    "print_frequency": 100,
    # False (the setting used so far): only the DSL loss is back-propagated, and the
    # object-attention loss is just logged. True: optimise dsl_loss + obj_loss.
    "train_obj_loss": False,
}

# Mixed precision is only used on CUDA. On CPU, autocast is explicitly disabled rather
# than requesting a CUDA autocast that cannot run.
use_amp = device == 'cuda'

optim = torch.optim.AdamW(pcfg_encoder.parameters(), lr=train_params['lr'])
avg_losses = []   # one (avg_total, avg_dsl, avg_obj) tuple per print interval
dsl_cls = None    # outputs of the last step, kept for inspection in the next cell
obj_att = None
itof = {v: k for k, v in dsl_list.items()}   # inverse of dsl_list (value -> key)

for epoch in range(train_params['epochs']):
    # NOTE(review): samples are pulled with next(data_yielder). Unless the yielder restarts
    # after raising StopIteration, every epoch after the first sees no data. Confirm this
    # before setting epochs > 1.
    step, dsl_loss_sum, obj_loss_sum, total_loss_sum = 0, 0, 0, 0

    while True:
        try:
            key, input, label, obj_idx = next(data_yielder)
        except StopIteration:
            break
        except loader.SampleExhausted:
            continue

        # Forward pass and loss computation run under autocast.
        with torch.amp.autocast(device_type=device, enabled=use_amp):
            x, first_pad = embed_input(input, pcfg_encoder, use_grads=not pcfg_encoder.freeze_emb_model, device=device)
            dsl_cls, obj_att = pcfg_encoder(x)
            obj_att = obj_att[2:first_pad + 2]  # Since we have 2 special cls tokens at the start

            # compute_kl_loss takes the same arguments and can be swapped in here.
            dsl_loss, obj_loss = compute_cross_entropy_loss(
                dsl_cls,
                obj_att,
                torch.tensor(label, device=device),
                torch.tensor(obj_idx, device=device)
            )
            total_loss = dsl_loss + obj_loss if train_params['train_obj_loss'] else dsl_loss

        # Backward pass and optimizer step run outside autocast, as the PyTorch AMP docs
        # recommend.
        optim.zero_grad(set_to_none=True)
        scaler.scale(total_loss).backward()
        scaler.step(optim)
        scaler.update()

        # Accumulate losses
        total_loss_sum += total_loss.item()
        dsl_loss_sum += dsl_loss.item()
        obj_loss_sum += obj_loss.item()
        step += 1

        if step % train_params['print_frequency'] == 0:
            avg_total_loss = total_loss_sum / train_params['print_frequency']
            avg_dsl_loss = dsl_loss_sum / train_params['print_frequency']
            avg_obj_loss = obj_loss_sum / train_params['print_frequency']

            avg_losses.append((avg_total_loss, avg_dsl_loss, avg_obj_loss))

            print(f"[{(epoch+1):>2}/{train_params['epochs']:>2}] - {step:>4}: Total Loss = {avg_total_loss:.4f}, DSL Loss = {avg_dsl_loss:.4f}, Obj Loss = {avg_obj_loss:.4f}")

            total_loss_sum = 0
            dsl_loss_sum = 0
            obj_loss_sum = 0

[ 1/ 1] -  100: Total Loss = 2.1754, DSL Loss = 2.1754, Obj Loss = 1.4404
[ 1/ 1] -  200: Total Loss = 2.1386, DSL Loss = 2.1386, Obj Loss = 1.5174
[ 1/ 1] -  300: Total Loss = 2.1451, DSL Loss = 2.1451, Obj Loss = 1.4750
[ 1/ 1] -  400: Total Loss = 2.1220, DSL Loss = 2.1220, Obj Loss = 1.5423
[ 1/ 1] -  500: Total Loss = 2.1291, DSL Loss = 2.1291, Obj Loss = 1.5546
[ 1/ 1] -  600: Total Loss = 2.1219, DSL Loss = 2.1219, Obj Loss = 1.6306
[ 1/ 1] -  700: Total Loss = 2.0601, DSL Loss = 2.0601, Obj Loss = 1.4783
[ 1/ 1] -  800: Total Loss = 2.0756, DSL Loss = 2.0756, Obj Loss = 1.5714
[ 1/ 1] -  900: Total Loss = 2.0711, DSL Loss = 2.0711, Obj Loss = 1.4920
[ 1/ 1] - 1000: Total Loss = 2.0955, DSL Loss = 2.0955, Obj Loss = 1.4747
[ 1/ 1] - 1100: Total Loss = 2.0084, DSL Loss = 2.0084, Obj Loss = 1.5276
[ 1/ 1] - 1200: Total Loss = 2.0471, DSL Loss = 2.0471, Obj Loss = 1.4302
[ 1/ 1] - 1300: Total Loss = 2.0184, DSL Loss = 2.0184, Obj Loss = 1.5320
[ 1/ 1] - 1400: Total Loss = 1.9928, D

In [7]:
# Inspect the last training step's DSL logits, their softmax, and the object-attention
# logits, one per line, then persist the trained model.
print(dsl_cls, F.softmax(dsl_cls, dim=-1), obj_att, sep='\n')
pcfg_encoder.save_model()

tensor([ 0.4062, -0.0432, -1.4668,  0.1714, -2.6191, -0.8252,  0.2338,  0.6240,
         0.6084], device='cuda:0', dtype=torch.float16,
       grad_fn=<SqueezeBackward4>)
tensor([0.1605, 0.1024, 0.0247, 0.1268, 0.0078, 0.0468, 0.1350, 0.1995, 0.1964],
       device='cuda:0', dtype=torch.float16, grad_fn=<SoftmaxBackward0>)
tensor([ 0.0194,  0.0800, -0.0224,  0.0195,  0.0760,  0.0933], device='cuda:0',
       grad_fn=<SliceBackward0>)
Model and embedding model saved to trained_pcfg_models\pcfg_encoder_20241204_091144.pth
