|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 1:</h2>|<h1>Tokenizations and embeddings</h1>|
|<h2>Section:</h2>|<h1>Embedding spaces</h1>|
|<h2>Lecture:</h2>|<h1><b>Loss function to train the embeddings</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt

import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
# reproduce the slide figure: a probability p plotted next to its logarithm
# (the grid starts just above zero because log(0) is -inf)
p = np.linspace(1e-7,.9,num=500)

# one wide, short panel; grab the axes so all drawing goes through one handle
plt.figure(figsize=(8,3))
ax = plt.gca()

# identity line and log curve, same line width
ax.plot(p,p,label='p',linewidth=2)
ax.plot(p,np.log(p),label='log(p)',linewidth=2)

# axis labels; x-axis runs from 0 to the largest sampled probability
ax.set(xlabel='Probability',ylabel='Loss value',xlim=[0,p[-1]])

plt.legend(fontsize=13)
plt.show()

In [None]:
# negative log-likelihood loss object, shared by the cells below
# ('mean' is PyTorch's default reduction; it is spelled out here for clarity)
loss_function = nn.NLLLoss(reduction='mean')

# list the attributes and methods available on the loss object
dir(loss_function)

In [None]:
# start with three outputs (raw model outputs for three tokens in the vocab)
model_output = torch.tensor([[ -1, 2.3, .1 ]],dtype=torch.float64)
print('Raw model outputs:')
print(f'  {model_output[0].tolist()}\n')

# NLLLoss expects log-probabilities, so convert the raw outputs with log-softmax
logsoftmax_output = F.log_softmax(model_output,dim=-1)
print('Log-softmax model outputs:')
print('  ',[round(o.item(),2) for o in logsoftmax_output[0] ],'\n')


# check the loss when each vocab index in turn is treated as the correct one
# (the loop index has its own name so it is not overwritten by the target tensor)
for target_idx in range(model_output.shape[-1]):

  # NLLLoss takes the target as a tensor of class indices, one per sample
  target = torch.tensor([target_idx])

  # calculate the loss for this choice of target
  theloss = loss_function(logsoftmax_output,target)

  # and print
  print(f'When the correct output is index "{target_idx}", the loss is {theloss.item():.2f}')

# Multi-sample losses (for batches)

In [None]:
# build one batch of four samples by stacking four copies of the single-sample output
batch_output = model_output.repeat(4,1)
print(batch_output,'\n')

# NLLLoss needs log-probabilities: apply log-softmax across the vocab (last) dimension
logsoftmax_output = F.log_softmax(batch_output,dim=-1)
print(logsoftmax_output)

In [None]:
# one target (index of the correct output) per sample in the batch of four
targets = torch.tensor([0,1,2,0])

In [None]:
# loss over the whole batch; loss_function was created with default settings,
# which for NLLLoss means the per-sample losses are averaged into one scalar
loss = loss_function(logsoftmax_output,targets)
print(loss)

In [None]:
# again, manual calculation: pick each sample's log-probability at its target
# index, sum over the batch, divide by the batch size, and flip the sign
# (for the values above: -( -3.4377 + -0.1377 + -2.3377 + -3.4377 ) / 4)
# computed from the tensors so it stays correct if the outputs or targets change
-logsoftmax_output[torch.arange(len(targets)),targets].sum() / len(targets)

In [None]:
# NOTE(review): model_output was created without requires_grad=True, so the loss
# computed from it has no autograd graph and this call is expected to raise a
# RuntimeError. Presumably intentional, as a lead-in to the next example
# (which sets requires_grad=True) — confirm.
loss.backward()

# Simple example in pytorch

In [None]:
# trainable "weights": a single row holding the raw scores for two categories
# (requires_grad=True makes autograd track operations on w)
w = torch.tensor([[-1,.2]], requires_grad=True)

# index of the correct category
target = torch.tensor([0])

# stochastic gradient descent acting directly on w
optimizer = torch.optim.SGD([w],lr=.5)

# number of gradient-descent updates
numTrainingIters = 10

# storage: one loss per iteration, and the weights before training plus after each update
allLosses = torch.zeros(numTrainingIters)
allWeights = torch.zeros((numTrainingIters+1,2))
allWeights[0,:] = w.detach()

# training loop
for i in range(numTrainingIters):

  # forward pass: log-softmax of the weights stands in for a full model
  modeloutput = F.log_softmax(w,dim=1)

  # negative log-likelihood of the target category
  loss = loss_function(modeloutput,target)
  allLosses[i] = loss.item()

  # backward pass: clear the old gradients, then compute the gradient of loss wrt w
  optimizer.zero_grad()
  loss.backward()

  # take one SGD step on w
  optimizer.step()

  # record the updated weights
  allWeights[i+1,:] = w.detach()

  # report progress
  rounded_w = [round(o.item(),2) for o in w[0] ]
  print(f"Step {i+1:2d}: loss = {loss.item():.3f}, weights = {rounded_w}")

In [None]:
# visualize training: losses on the left, the two weights on the right
_,axs = plt.subplots(1,2,figsize=(10,3))
loss_ax,weight_ax = axs

# left panel: the loss recorded at each training iteration
loss_ax.plot(allLosses,'ks-',linewidth=1,markerfacecolor=[.9,.7,.7])
loss_ax.set(xlabel='Training epochs',ylabel='Loss value',title='Losses during training')

# right panel: one curve per column of allWeights (first row holds the initial weights)
weight_styles = (
  ('ks-',[.7,.9,.7],'Weight 0 (target)'),
  ('ko-',[.7,.7,.9],'Weight 1 (non-target)'),
)
for col,(fmt,facecolor,label) in enumerate(weight_styles):
  weight_ax.plot(allWeights[:,col],fmt,markerfacecolor=facecolor,linewidth=1,label=label)
weight_ax.set(xlabel='Training epochs',ylabel='Weight value',title='Weight values')
weight_ax.legend()

plt.tight_layout()
plt.show()