In [1]:
"""
The second half of this code takes in a bruker mk pt xml, extracts the existing trial templates, and replicates those 
trials according to a trial condition list provided.

Note the template needs to have an initial trigger after N frames dummy. For some reason Bruker sends an initial TTL trigger 
at the beginning of a session
"""

import copy
import csv
import os
import xml.etree.ElementTree as etree

import numpy as np

In [2]:
# --- Session configuration ---------------------------------------------------
xmlname = 'template'        # basename of the template mark-point xml in fdir
fname = 'fullSession'       # basename for every file this notebook writes
fdir = r'Y:\charles_david\olfactometer_order'
counterbalance = 3          # vial (1, 2 or 3) that becomes order[0]

# Trial counts for the order[0] / order[1] / order[2] cues respectively.
trials1 = 30
trials2 = 150
trials3 = 30
stimTrials = 15  # length of the base stimulation pattern (tiled 10 times)

# Cue (vial) order, rotated so that it starts at `counterbalance`:
#   1 -> [1, 2, 3],  2 -> [2, 3, 1],  3 -> [3, 1, 2]
if counterbalance not in (1, 2, 3):
    raise ValueError('counterbalance must be 1, 2 or 3, got {!r}'.format(counterbalance))
order = np.roll(np.arange(1, 4, dtype=int), 1 - counterbalance)

# --- Path declarations -------------------------------------------------------
save_path = fdir
fname_arduino = fname + '_arduino.txt'  # NOTE(review): not referenced by any cell shown; kept for compatibility
fname_arduinoRew = fname + '_arduinoRew.txt'
fname_arduinoStim = fname + '_arduinoStim.txt'
fname_olfactometer = fname + '_olfactometer.txt'
xml_fpath = os.path.join(fdir, xmlname + '.xml')    # template xml that is read
xml_savepath = os.path.join(fdir, fname + '.xml')   # generated session xml

In [7]:
def get_trial_types(root_xml):
    """Detach the template trial elements from a parsed mark-point xml.

    Every ``PVMarkPointElement`` found under ``root_xml`` (at any depth) except
    the first is treated as a trial template. The first one is the dummy
    element for Bruker's initial TTL trigger and is left in place.

    Args:
        root_xml: root Element of the parsed template xml.

    Returns:
        dict mapping ``'trial_0'``, ``'trial_1'``, ... (document order) to the
        detached template elements.

    Side effect:
        the template elements are removed from the tree so the caller can
        re-append them in a new order.
    """
    # Element.remove() only removes *direct* children, but the './/' search
    # also matches nested elements; map every element to its parent so the
    # removal works at any depth instead of raising ValueError.
    parent_of = {child: parent for parent in root_xml.iter() for child in parent}

    element_dict = {}
    for idx, element in enumerate(root_xml.findall('.//PVMarkPointElement')[1:]):
        element_dict['trial_{}'.format(idx)] = element
        # get rid of existing trials; we will repopulate
        parent_of[element].remove(element)

    return element_dict


def add_mk_pt_trials(trial_IDs, element_dict, root_xml):
    """Append one copy of a template trial to ``root_xml`` per trial ID.

    Args:
        trial_IDs: iterable of template indices (ints or numpy ints); ID ``k``
            selects ``element_dict['trial_k']``.
        element_dict: template elements keyed ``'trial_<k>'``, as returned by
            ``get_trial_types``.
        root_xml: Element the trials are appended to, in ``trial_IDs`` order.

    Raises:
        KeyError: if an ID has no matching template.
    """
    # Iterate the argument, not a notebook global, so callers control the order.
    for trial_ID in trial_IDs:
        key = 'trial_{}'.format(trial_ID)
        if key not in element_dict:
            raise KeyError('no trial template {!r}; available: {}'.format(
                key, sorted(element_dict)))
        # Append an independent copy: appending the same Element object
        # repeatedly makes every occurrence alias one node, so a later edit
        # to one trial would silently change all of them.
        root_xml.append(copy.deepcopy(element_dict[key]))
        
        
def make_mk_pt_xml_main(trial_IDs, xml_fpath, xml_savepath):
    """Rebuild a Bruker mark-point xml with its trials in a new order.

    Parses the template at ``xml_fpath``, detaches its trial elements with
    ``get_trial_types``, and hands ``trial_IDs`` plus those templates to
    ``add_mk_pt_trials``, which re-appends them to the root. It then writes
    the result to ``xml_savepath``.

    Args:
        trial_IDs: sequence of 0-based trial-template indices (the leading
            dummy element is not counted).
        xml_fpath: path of the template xml to read.
        xml_savepath: path the rebuilt xml is written to.
    """
    tree = etree.parse(xml_fpath)
    root = tree.getroot()

    # Pull the template trials out of the tree (leaves the dummy element) ...
    templates = get_trial_types(root)
    # ... then append them back in the requested order.
    add_mk_pt_trials(trial_IDs, templates, root)

    tree.write(xml_savepath)

In [5]:
# Generate the stimulation trial order: a fixed, hand-balanced 15-trial
# pattern of stimulation groups 0/1/2 (five of each), tiled 10 times.
# NOTE(review): the 150 entries presumably correspond to the 150 stimulated
# trials (stimType == 1, i.e. trials2 in the config cell) -- confirm.
stim_pattern = np.array([0, 0, 1, 2, 1, 0, 2, 2, 1, 1, 0, 2, 0, 1, 2], dtype=int)
stim_repeats = 10
stimGroup = np.tile(stim_pattern, stim_repeats)

print(stimGroup)
        

[0 0 1 2 1 0 2 2 1 1 0 2 0 1 2 0 0 1 2 1 0 2 2 1 1 0 2 0 1 2 0 0 1 2 1 0 2
 2 1 1 0 2 0 1 2 0 0 1 2 1 0 2 2 1 1 0 2 0 1 2 0 0 1 2 1 0 2 2 1 1 0 2 0 1
 2 0 0 1 2 1 0 2 2 1 1 0 2 0 1 2 0 0 1 2 1 0 2 2 1 1 0 2 0 1 2 0 0 1 2 1 0
 2 2 1 1 0 2 0 1 2 0 0 1 2 1 0 2 2 1 1 0 2 0 1 2 0 0 1 2 1 0 2 2 1 1 0 2 0
 1 2]


In [8]:
# Build the session mark-point xml from the template at xml_fpath, with trials
# in stimGroup order, and save it to xml_savepath.
make_mk_pt_xml_main(trial_IDs=stimGroup,
                    xml_fpath=xml_fpath,
                    xml_savepath=xml_savepath)

In [66]:
# Generate the cue (olfactometer vial) order: trials1 trials of vial order[0],
# trials2 of order[1] and trials3 of order[2], randomly shuffled.
# NOTE(review): no random seed is set, so every run yields a new order; the
# Arduino and olfactometer files must be written from the same run.
cueType = np.repeat(order, [trials1, trials2, trials3]).astype(int)
np.random.shuffle(cueType)
print(cueType)

# Per-trial reward (trialType) and stimulation (stimType) flags:
#   order[0] cue -> always rewarded (1), not stimulated
#   order[1] cue -> rewarded with probability 0.5, stimulated
#   order[2] cue -> never rewarded (0), not stimulated
is_stim_cue = cueType == order[1]
trialType = (cueType == order[0]).astype(int)
# One vectorised draw replaces per-trial `np.random.randint(2, size=1)`, whose
# size-1 *array* result was stored into a scalar slot (deprecated since
# NumPy 1.25).
trialType[is_stim_cue] = np.random.randint(2, size=int(is_stim_cue.sum()))
stimType = is_stim_cue.astype(int)

print(trialType)
print(stimType)

[2 1 3 1 3 1 1 1 1 1 3 1 2 1 1 1 1 1 1 2 1 1 2 1 3 1 1 1 3 1 1 2 3 1 3 1 2
 1 1 2 1 1 1 2 1 3 1 1 3 1 1 1 1 3 1 1 1 3 1 1 1 3 1 1 2 1 1 1 1 1 2 1 1 1
 1 1 1 1 1 1 1 2 1 1 1 1 1 3 1 3 1 1 1 1 1 1 1 1 2 1 1 1 1 3 2 1 1 1 3 1 2
 1 1 2 1 1 3 2 1 3 1 2 3 1 1 1 1 1 1 1 2 1 1 3 1 1 1 1 1 1 1 2 2 1 1 2 2 1
 1 1 1 1 1 1 3 3 1 3 1 1 1 2 1 2 2 1 1 1 2 1 1 3 1 1 1 1 3 1 3 3 1 1 3 1 1
 1 1 1 3 1 2 1 1 1 1 1 1 1 1 3 1 2 1 1 1 1 1 2 1 2]
[0 0 1 1 1 0 0 0 0 0 1 1 0 0 0 0 1 0 0 0 1 0 0 1 1 0 0 1 1 0 1 0 1 1 1 0 0
 1 0 0 0 0 1 0 1 1 1 0 1 0 1 1 1 1 0 0 1 1 1 1 0 1 0 0 0 1 1 1 1 1 0 0 0 0
 1 1 1 0 0 1 1 0 1 1 1 0 0 1 1 1 1 1 0 0 1 1 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0
 1 0 0 1 0 1 0 1 1 1 0 1 0 1 1 0 1 1 0 0 1 1 1 0 1 1 1 1 1 1 0 0 1 1 0 0 1
 1 1 1 0 0 1 1 1 1 1 0 1 1 0 0 0 0 0 0 0 0 1 0 1 0 1 1 1 1 0 1 1 0 1 1 1 1
 0 0 0 1 0 0 0 1 0 1 1 0 0 0 1 0 0 1 0 0 0 1 0 1 0]
[0 1 0 1 0 1 1 1 1 1 0 1 0 1 1 1 1 1 1 0 1 1 0 1 0 1 1 1 0 1 1 0 0 1 0 1 0
 1 1 0 1 1 1 0 1 0 1 1 0 1 1 1 1 0 1 1 1 0 1 1 1 0 1 1 0 1 1 1 1 1 0 1 

In [67]:
# Write the per-trial vectors to text files that are pasted into the Arduino
# and olfactometer code. Paths are joined onto save_path explicitly instead of
# calling os.chdir(), so the kernel's working directory is left untouched.
# newline='' is required for csv.writer: it emits its own '\r\n' terminator,
# and text-mode newline translation on Windows would turn that into '\r\r\n'.
with open(os.path.join(save_path, fname_arduinoRew), 'w', newline='') as rew_file:
    csv.writer(rew_file, delimiter=',').writerow(trialType)

with open(os.path.join(save_path, fname_arduinoStim), 'w', newline='') as stim_file:
    csv.writer(stim_file, delimiter=',').writerow(stimType)

# Olfactometer sequence file: a title line, column headers, then one
# tab-separated line per trial starting with that trial's vial (cue) number.
with open(os.path.join(save_path, fname_olfactometer), 'w') as olffile:
    olffile.write("206A Olfactometer Sequence File")
    olffile.write("\nVial\tDelivery (ms)\tDelay (s)")
    for cue in cueType:
        olffile.write("\n" + str(cue) + "\tTrig\tEdge")

# Generate xml from scratch

In [12]:
# Generate a mark-point xml from scratch (instead of editing a template):
# a zero-power dummy element triggered on PFI8, followed by one trial element
# triggered on TrigIn.
# xml.etree.cElementTree was deprecated in Python 3.3 and removed in 3.9, so
# importing it fails on current Python; ElementTree picks up the C
# accelerator automatically.
import xml.etree.ElementTree as ET


def _add_mark_point(parent, trigger_selection, laser_power, repetitions):
    """Append a PVMarkPointElement with this session's fixed settings.

    Only the trigger input, uncaging laser power and repetition count differ
    between elements. Attributes are passed in the original order, so the
    written xml is unchanged.
    """
    return ET.SubElement(parent, 'PVMarkPointElement',
                         parameterSet="CurrentSettings",
                         VoltageRecCategoryName="None",
                         VoltageOutputCategoryName="None",
                         AsyncSyncFrequency="None",
                         TriggerCount="1",
                         TriggerSelection=trigger_selection,
                         TriggerFrequency="FirstRepetition",
                         UncagingLaserPower=laser_power,
                         UncagingLaser="Spirit 1040nm",
                         Repetitions=repetitions)


root = ET.Element("PVSavedMarkPointSeriesElements",
                  CalcFunctMap='False',
                  IterationDelay="0.00",
                  Iterations="1")

# Dummy element at laser power 0 -- presumably the one that absorbs Bruker's
# initial TTL trigger; confirm against the rig.
ele_dummy_trial = _add_mark_point(root, "PFI8", "0", "1")
ET.SubElement(ele_dummy_trial, "PVGalvoPointElement",
              Indices="1",
              Points="Point 1",
              AllPointsAtOnce="False",
              SpiralRevolutions="3",
              Duration="0.79",
              InterPointDelay="0.12",
              InitialDelay="0.12")

# Trial element: all points of "Group 1" at once, repeated 15 times.
ele_trial = _add_mark_point(root, "TrigIn", "0.6", "15")
ET.SubElement(ele_trial, "PVGalvoPointElement",
              Indices="1-10",
              Points="Group 1",
              AllPointsAtOnce="True",
              SpiralRevolutions="3",
              Duration="10",
              InterPointDelay="56.67",
              InitialDelay="7.5")

tree = ET.ElementTree(root)
tree.write(os.path.join(fdir, fname + '_scratch.xml'))