In [1]:
import argparse
import itertools
import os.path
import time
import sys
sys.path.append(os.getcwd() + '/Code/')

import torch
import torch.optim.lr_scheduler
import torch.nn as nn
import numpy as np
from pathlib import Path

import evaluate
import trees
import vocabulary
import nkutil
import parse_nk
import csv
import matplotlib.pyplot as plt
import random
import pdb

import pandas as pd
import seaborn as sns

import warnings
warnings.filterwarnings('ignore')

Not using CUDA!


In [2]:
# Load the pretrained parser checkpoint from <cwd>/Data and rebuild the model.

checkpoint_path = str(Path(os.getcwd())) + "/Data/en_charlstm_dev.93.61.pt"
# NOTE(review): torch.load unpickles the file, so only load checkpoints
# from a trusted source.
temp = torch.load(checkpoint_path)
parser = parse_nk.NKChartParser.from_spec(temp['spec'], temp['state_dict'])

# Tag paired with every word fed to the parser: 'UNK' when the tag
# vocabulary has it, otherwise whatever tag sits at index 0.
dummy_tag = 'UNK' if 'UNK' in parser.tag_vocab.indices else parser.tag_vocab.value(0)

In [3]:
# Routine to get activations from intermediate layers

def process_sequences_perlayer(seq_batch, num_layers=8):
    """Run the parser on ``seq_batch`` and collect per-layer ReLU activations.

    A forward hook is attached to ``parser.encoder.ff_<i>.relu`` for every
    ``i`` in ``range(num_layers)``, the batch is parsed once, and the hooks
    are removed again so that repeated calls do not pile up hooks (and the
    tensors they captured) on the shared module-level ``parser``.

    Parameters
    ----------
    seq_batch : list
        Batch handed unchanged to ``parser.parse_batch``.
    num_layers : int, optional
        Number of feed-forward layers to record (default 8, i.e.
        ``ff_0`` ... ``ff_7``).

    Returns
    -------
    numpy.ndarray
        One entry per layer, stacked along the first axis. Each entry is the
        first output captured for that layer with its first and last rows
        removed.
        NOTE(review): the dropped rows are presumably the sentence start/stop
        positions added by the parser -- confirm against parse_nk.
    """
    captured = [[] for _ in range(num_layers)]

    def make_hook(store):
        # Bind ``store`` per layer; a plain closure inside the loop would
        # make every hook write to the last layer's list.
        def hook(module, input, output):
            store.append(output)
        return hook

    handles = []
    try:
        for layer_idx in range(num_layers):
            relu = getattr(parser.encoder, 'ff_%d' % layer_idx).relu
            handles.append(relu.register_forward_hook(make_hook(captured[layer_idx])))

        # Only the side effect on ``captured`` matters; the parse result
        # itself is not used.
        parser.parse_batch(seq_batch)
    finally:
        # Always detach the hooks, even when parsing raises, so the parser
        # is left exactly as it was found.
        for handle in handles:
            handle.remove()

    # detach()/cpu() make the conversion safe for tensors that track
    # gradients or live on an accelerator; they are no-ops otherwise.
    temp_acts = np.array([layer_outputs[0].detach().cpu().numpy()[1:-1, :]
                          for layer_outputs in captured])
    return temp_acts


In [7]:
def create_sentences(SenType = 'sen'):
    """Build four-word sequences of the requested type.

    The rows of ``<cwd>/Data/Ding_grammatical.csv`` are read as
    ``[adjective, noun, verb, noun]`` sentences; the three word pools
    (adjectives, nouns, verbs) are derived from those columns.

    Parameters
    ----------
    SenType : str
        One of:

        * ``'sen'``        -- the sentences exactly as stored in the CSV
        * ``'wordlist'``   -- 4 distinct words drawn from all pools combined
        * ``'nongrammar'`` -- adjective, noun, verb, noun (random words)
        * ``'nounphrase'`` -- adjective, noun, adjective, noun
        * ``'verbphrase'`` -- verb, noun, verb, noun
        * ``'random'``     -- verb, noun, noun, adjective

    Returns
    -------
    list of list of str
        For ``'sen'`` every non-blank CSV row; for all other types 60
        randomly generated sequences.

    Raises
    ------
    ValueError
        If ``SenType`` is not one of the names above (previously this
        surfaced as an ``UnboundLocalError`` at the ``return``).
    """
    # Word-class order for each randomly generated sequence type.
    patterns = {
        'nongrammar': ('adjective', 'noun', 'verb', 'noun'),
        'nounphrase': ('adjective', 'noun', 'adjective', 'noun'),
        'verbphrase': ('verb', 'noun', 'verb', 'noun'),
        'random': ('verb', 'noun', 'noun', 'adjective'),
    }
    if SenType not in ('sen', 'wordlist') and SenType not in patterns:
        raise ValueError(
            "Unknown SenType %r; expected one of %s"
            % (SenType, ['sen', 'wordlist'] + sorted(patterns)))

    base = str(Path(os.getcwd()))

    # Read the file once (it used to be opened a second time for 'sen').
    # newline='' is the mode the csv module asks for; blank lines come back
    # as empty rows and are skipped.
    with open(base + '/Data/Ding_grammatical.csv', 'r', newline='') as f:
        all_sentences = [row for row in csv.reader(f) if row]

    if SenType == 'sen':
        return all_sentences

    # Sorted so the pools have the same order on every run; iterating a set
    # of strings directly is not stable across interpreter sessions, which
    # would defeat seeding ``random``.
    pools = {
        'adjective': sorted({row[0] for row in all_sentences}),
        'noun': sorted({word for row in all_sentences for word in (row[1], row[3])}),
        'verb': sorted({row[2] for row in all_sentences}),
    }

    N = 60  # number of generated sequences
    if SenType == 'wordlist':
        all_words = pools['noun'] + pools['adjective'] + pools['verb']
        # sample() draws without replacement: no word repeats within a list.
        return [random.sample(all_words, 4) for _ in range(N)]

    pattern = patterns[SenType]
    return [[random.choice(pools[word_class]) for word_class in pattern]
            for _ in range(N)]

In [5]:
def gen_data_matrix(seq_type = 'sen', n_samples = 20, seqs_per_sample = 13):
    """Build word streams by concatenating randomly chosen sequences.

    Parameters
    ----------
    seq_type : str
        Sequence type forwarded to ``create_sentences`` as ``SenType``.
    n_samples : int, optional
        Number of word streams to generate (default 20).
    seqs_per_sample : int, optional
        Number of distinct sequences concatenated into each stream
        (default 13).

    Returns
    -------
    list of list of str
        ``n_samples`` flat word lists; each is the words of
        ``seqs_per_sample`` sequences drawn without replacement, in draw
        order.

    Raises
    ------
    ValueError
        Raised by ``random.sample`` when fewer than ``seqs_per_sample``
        sequences are available.
    """
    all_sequences = create_sentences(SenType=seq_type)
    # Kept so the sequence of random draws matches the original
    # implementation; sample() below already picks uniformly at random.
    random.shuffle(all_sequences)

    all_samples = []
    for _ in range(n_samples):
        chosen = random.sample(all_sequences, seqs_per_sample)
        # Flatten the chosen sequences into one continuous word stream.
        all_samples.append([word for seq in chosen for word in seq])
    return all_samples

In [11]:
# Run simulations for all conditions
saveoutput = 0  # set to 1 to write one activation file per condition
seq_types = ['sen','wordlist','nongrammar','nounphrase','verbphrase','random']

for i,seq_type in enumerate(seq_types):
    seq_input = gen_data_matrix(seq_type=seq_type)
    # NOTE(review): never filled in this cell; kept in case a later cell
    # relies on the name.
    mean_powers = [[] for j in range(8)] # number of feed forward layers considered
    all_activations = []
    for seq in seq_input:
        # One-sentence batch: every word is paired with the placeholder tag.
        subbatch_sentences = [[(dummy_tag, word) for word in seq]]
        activations = process_sequences_perlayer(subbatch_sentences)
        all_activations.append(activations)
    if saveoutput:
        # The missing '+' between '/Data/' and seq_type made this cell a
        # SyntaxError. np.save appends '.npy' to the file name.
        np.save(str(Path(os.getcwd())) + '/Data/' + seq_type,
                np.array(all_activations))