# Transform Attentions

With confidence in our transformation process that we investigated in exploration/transform_prototyping, we process our first 2000 squad examples worth of attentions from exploration/extract_attentions.ipynb as 20 * 10GB binary pkl files.  This resulted in a 6.6GB CSV.

Transform Steps (for each file):

• Load attention binary  
• Scale attention values to 0-255  
• Reshape to (1, 3, 384, 384) tensor  
• Extract features using modified Barlow Twins  
• Flatten 12x12 representations  
• Convert tensors to dataframe columns  
• Append to dataframe  

Results are then output to representations_df.csv which was used in initial clustering exploration and scaling analysis.


In [None]:
import torch
import time
import os
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import torch.nn as nn
import pandas as pd

In [7]:
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler(feature_range=(0, 255)) # scaler 

In [8]:
data_dir='/tf/notebooks/QA_attentions_pickled'
output_dir='/tf/notebooks/QA_attentions_pickled/representations'

In [9]:
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
        
    def forward(self, x):
        return x    

In [10]:
model = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')
model.fc = Identity() # pass through values from second to last layer, bypassing linear classifier

Using cache found in /root/.cache/torch/hub/facebookresearch_barlowtwins_main


In [11]:
cuda = torch.device('cuda:0')
model.to(cuda)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [12]:
def plot_layer_heads(attention, num_layers=12, num_heads=12):
  for i in range(0,11):
    p = attention[i]
    fig, axis = plt.subplots(1,12, figsize=(20,5), facecolor='w', edgecolor='k')
    plt.title(f'layer {i}')
    head = 0
    for axs, ph in zip(axis.flatten(), p):
      heatmap = axs.imshow(ph, cmap='hot')

In [13]:
def scale_examples(examples):
    scaled_examples = np.empty(shape=(100), dtype=np.ndarray)
    for i, example in enumerate(examples):
        new_example = np.empty(shape=(12,12), dtype=np.ndarray)
        for l, layer in enumerate(example): #12 layers
            new_layer = np.array([])
            for h, head in enumerate(layer): #12 heads
                flat_head_transformed = scaler.fit_transform(head)
                new_example[l,h] = flat_head_transformed.reshape(384,384)
        scaled_examples[i] = new_example
    return scaled_examples

In [14]:
# create (1, 3, 384, 384) shape expected by barlow twins model
def reshape_example(image):
    example_channel = np.expand_dims(image, axis=0)
    batch = np.append(example_channel, example_channel, axis=0)
    batch = np.append(batch, example_channel, axis=0)
    example_3channel = np.expand_dims(batch, axis=0)
    return example_3channel

In [15]:
def get_representations(attentions):
    barlow_representations = np.zeros((100), np.object)
    for i, example in enumerate(attentions):
        reshaped_example = np.zeros((12,12), np.object)
        for l, layer in enumerate(example):
            for h, head in enumerate(layer):
                reshaped_head = torch.from_numpy(reshape_example(head)).to(cuda)
                representation_head = model(reshaped_head.float())
                reshaped_example[l][h] = representation_head.detach().cpu().numpy()
        barlow_representations[i] = reshaped_example

    return barlow_representations

In [16]:
def flatten_layer_heads(representations_tensor):
    print("flattening layers/heads ...")
    flat_array = np.zeros((14400), np.ndarray)
    i = 0
    for example in representations_tensor:
        for layer in example:
            for h, head in enumerate(layer):
                flat_array[i] = head[0]
                i += 1
    return flat_array

In [18]:
batch_size = 100
batch_num = 0
representation_df = pd.DataFrame()
representation_array = []
for i in range(1,21):
    start_time = time.time()
    batch_num = i * batch_size
    print(f"Loading attentions batch {batch_num}")
    attentions = torch.load(os.path.join(data_dir, f"eval_attentions_{batch_num}.bin"))
    print("Scaling attention values to 0-255 ...")
    scaled_attentions =  scale_examples(attentions)
    print("Processing to 2048 value representations through barlow_twins ...")
    barlow_representations = get_representations(scaled_attentions)
    print("Appending results to array/dataframe ...")
    flat_representations = flatten_layer_heads(barlow_representations)
    representation_array.append(flat_representations)
    df = pd.DataFrame(flat_representations)
    df = pd.DataFrame([pd.Series(x) for x in df[0]])
    representation_df = representation_df.append(df, ignore_index=True)
    print(f"--- eval to representation batch {batch_num} in  {(time.time() - start_time)} seconds ---")

Loading attentions batch 100
Scaling attention values to 0-255 ...
Processing to 2048 value representations through barlow_twins ...


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Appending results to array/dataframe ...
flattening layers/heads ...
--- eval to representation batch 100 in  217.97882652282715 seconds ---
Loading attentions batch 200
Scaling attention values to 0-255 ...
Processing to 2048 value representations through barlow_twins ...
Appending results to array/dataframe ...
flattening layers/heads ...
--- eval to representation batch 200 in  226.08006763458252 seconds ---
Loading attentions batch 300
Scaling attention values to 0-255 ...
Processing to 2048 value representations through barlow_twins ...
Appending results to array/dataframe ...
flattening layers/heads ...
--- eval to representation batch 300 in  224.35632610321045 seconds ---
Loading attentions batch 400
Scaling attention values to 0-255 ...
Processing to 2048 value representations through barlow_twins ...
Appending results to array/dataframe ...
flattening layers/heads ...
--- eval to representation batch 400 in  227.2799997329712 seconds ---
Loading attentions batch 500
Scaling a

In [19]:
torch.save(representation_array,os.path.join(data_dir, f"representation_array.bin"))

In [20]:
torch.save(representation_df,os.path.join(data_dir, f"representation_df.bin"))

In [21]:
representation_df.to_csv(os.path.join(data_dir, f"representation_df.csv"))

In [25]:
representation_df.head(1)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,2038,2039,2040,2041,2042,2043,2044,2045,2046,2047
0,0.001772,0.021541,0.001335,0.022821,0.024328,0.0,0.032967,0.122129,0.014402,0.016769,...,0.057488,0.025694,0.012562,0.095822,0.0,0.02446,0.003141,0.02578,0.00421,0.00793


In [None]:
df = pd.DataFrame(flat_representations)

In [41]:
df = pd.DataFrame([pd.Series(x) for x in df[0]])

In [46]:
representation_df = representation_df.append(df, ignore_index=True).head(1)

In [47]:
representation_df.head(1)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,2038,2039,2040,2041,2042,2043,2044,2045,2046,2047
0,0.001772,0.021541,0.001335,0.022821,0.024328,0.0,0.032967,0.122129,0.014402,0.016769,...,0.057488,0.025694,0.012562,0.095822,0.0,0.02446,0.003141,0.02578,0.00421,0.00793


In [28]:
representation_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 288000 entries, 0 to 287999
Columns: 2048 entries, 0 to 2047
dtypes: float32(2048)
memory usage: 2.2 GB


In [29]:
representation_df.describe

<bound method NDFrame.describe of             0         1         2         3         4         5         6     \
0       0.001772  0.021541  0.001335  0.022821  0.024328  0.000000  0.032967   
1       0.000670  0.007020  0.084146  0.023934  0.032102  0.000000  0.030618   
2       0.009630  0.009985  0.004485  0.030183  0.045801  0.002238  0.050636   
3       0.002628  0.004462  0.012899  0.028588  0.077253  0.000000  0.007670   
4       0.001193  0.077093  0.037766  0.050492  0.005282  0.003195  0.038456   
...          ...       ...       ...       ...       ...       ...       ...   
287995  0.000908  0.006504  0.013558  0.011609  0.000802  0.007187  0.044382   
287996  0.000000  0.008056  0.052457  0.006509  0.014854  0.000809  0.022595   
287997  0.000274  0.014337  0.031389  0.016458  0.006799  0.000000  0.005014   
287998  0.000066  0.013728  0.034654  0.015540  0.001389  0.000000  0.009813   
287999  0.000000  0.000373  0.020018  0.014481  0.000000  0.018589  0.005801   

     